git_commit_sage/
protocol.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::fmt::Debug;
4
5/// Represents a message in a conversation
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Message {
8    pub role: String,
9    pub content: String,
10}
11
12/// Configuration for model generation
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct GenerationConfig {
15    pub temperature: f32,
16    pub max_tokens: u32,
17    pub stop_sequences: Vec<String>,
18}
19
20/// Context for model interaction
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ModelContext {
23    pub messages: Vec<Message>,
24    pub config: GenerationConfig,
25}
26
27/// Trait for model providers (e.g., Together.ai, OpenAI, local models)
28#[async_trait]
29pub trait ModelProvider: Send + Sync {
30    /// The error type returned by this provider
31    type Error: std::error::Error + Send + Sync + 'static;
32
33    /// Generate a response using the provided context
34    async fn generate(&self, context: ModelContext) -> Result<String, Self::Error>;
35
36    /// Get the model identifier
37    fn model_id(&self) -> &str;
38
39    /// Get the default configuration for this model
40    fn default_config(&self) -> GenerationConfig;
41}
42
43/// Trait for commit message generators
44#[async_trait]
45pub trait CommitMessageGenerator: Send + Sync {
46    /// The error type returned by this generator
47    type Error: std::error::Error + Send + Sync + 'static;
48
49    /// Generate a commit message from a diff
50    async fn generate_message(&self, diff: &str) -> Result<String, Self::Error>;
51
52    /// Validate a commit message format
53    fn validate_message(&self, message: &str) -> bool;
54}
55
56/// Implementation of CommitMessageGenerator for any ModelProvider
57#[async_trait]
58impl<T: ModelProvider> CommitMessageGenerator for T {
59    type Error = T::Error;
60
61    async fn generate_message(&self, diff: &str) -> Result<String, Self::Error> {
62        let context = ModelContext {
63            messages: vec![
64                Message {
65                    role: "system".to_string(),
66                    content: "You are a highly skilled developer who writes perfect conventional commit messages. \
67                        You analyze git diffs and generate commit messages following the Conventional Commits specification. \
68                        Your messages should be descriptive and precise, following this format:\n\
69                        - For small changes: type(scope): concise description\n\
70                        - For large changes (>5 files or >100 lines): type(scope): comprehensive description of main changes\n\
71                        The type must be one of: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert.\n\
72                        The scope should reflect the main component being changed.\n\
73                        The description should be clear, precise, and written in imperative mood.\n\
74                        For large changes, ensure the description captures the major components being modified.".to_string(),
75                },
76                Message {
77                    role: "user".to_string(),
78                    content: format!(
79                        "Generate a conventional commit message for the following git diff.\n\
80                        The message must follow the conventional commit format.\n\
81                        If the diff is large (>5 files or >100 lines), make the description more comprehensive.\n\
82                        Only return the commit message, nothing else.\n\n\
83                        Diff:\n{}", 
84                        diff
85                    ),
86                },
87            ],
88            config: self.default_config(),
89        };
90
91        self.generate(context).await
92    }
93
94    fn validate_message(&self, message: &str) -> bool {
95        // Basic format: <type>[optional scope]: <description>
96        let parts: Vec<&str> = message.splitn(2, ": ").collect();
97        if parts.len() != 2 {
98            return false;
99        }
100
101        let type_part = parts[0];
102        let commit_type = if type_part.contains('(') {
103            type_part.split('(').next().unwrap_or("")
104        } else {
105            type_part
106        };
107
108        let valid_types = [
109            "feat", "fix", "docs", "style", "refactor",
110            "perf", "test", "build", "ci", "chore", "revert"
111        ];
112
113        valid_types.contains(&commit_type)
114    }
115}
116
117/// Together.ai implementation of ModelProvider
118pub struct TogetherAiProvider {
119    api_key: String,
120    model: String,
121    client: reqwest::Client,
122}
123
124#[async_trait]
125impl ModelProvider for TogetherAiProvider {
126    type Error = crate::Error;
127
128    async fn generate(&self, context: ModelContext) -> Result<String, Self::Error> {
129        let request = serde_json::json!({
130            "model": self.model,
131            "messages": context.messages,
132            "temperature": context.config.temperature,
133            "max_tokens": context.config.max_tokens,
134            "stop": context.config.stop_sequences,
135        });
136
137        let response = self.client
138            .post("https://api.together.xyz/v1/chat/completions")
139            .header("Authorization", format!("Bearer {}", self.api_key))
140            .json(&request)
141            .send()
142            .await?
143            .error_for_status()?
144            .json::<serde_json::Value>()
145            .await?;
146
147        response["choices"][0]["message"]["content"]
148            .as_str()
149            .map(|s| s.trim().to_string())
150            .ok_or_else(|| crate::Error::CommitMessageGeneration("No response from API".to_string()))
151    }
152
153    fn model_id(&self) -> &str {
154        &self.model
155    }
156
157    fn default_config(&self) -> GenerationConfig {
158        GenerationConfig {
159            temperature: 0.3,
160            max_tokens: 100,
161            stop_sequences: vec!["\n".to_string()],
162        }
163    }
164}
165
166impl TogetherAiProvider {
167    pub fn new(api_key: String, model: String) -> Self {
168        Self {
169            api_key,
170            model,
171            client: reqwest::Client::new(),
172        }
173    }
174}