git_commit_sage/
protocol.rs1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::fmt::Debug;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Message {
8 pub role: String,
9 pub content: String,
10}
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct GenerationConfig {
15 pub temperature: f32,
16 pub max_tokens: u32,
17 pub stop_sequences: Vec<String>,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ModelContext {
23 pub messages: Vec<Message>,
24 pub config: GenerationConfig,
25}
26
27#[async_trait]
29pub trait ModelProvider: Send + Sync {
30 type Error: std::error::Error + Send + Sync + 'static;
32
33 async fn generate(&self, context: ModelContext) -> Result<String, Self::Error>;
35
36 fn model_id(&self) -> &str;
38
39 fn default_config(&self) -> GenerationConfig;
41}
42
43#[async_trait]
45pub trait CommitMessageGenerator: Send + Sync {
46 type Error: std::error::Error + Send + Sync + 'static;
48
49 async fn generate_message(&self, diff: &str) -> Result<String, Self::Error>;
51
52 fn validate_message(&self, message: &str) -> bool;
54}
55
56#[async_trait]
58impl<T: ModelProvider> CommitMessageGenerator for T {
59 type Error = T::Error;
60
61 async fn generate_message(&self, diff: &str) -> Result<String, Self::Error> {
62 let context = ModelContext {
63 messages: vec![
64 Message {
65 role: "system".to_string(),
66 content: "You are a highly skilled developer who writes perfect conventional commit messages. \
67 You analyze git diffs and generate commit messages following the Conventional Commits specification. \
68 Your messages should be descriptive and precise, following this format:\n\
69 - For small changes: type(scope): concise description\n\
70 - For large changes (>5 files or >100 lines): type(scope): comprehensive description of main changes\n\
71 The type must be one of: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert.\n\
72 The scope should reflect the main component being changed.\n\
73 The description should be clear, precise, and written in imperative mood.\n\
74 For large changes, ensure the description captures the major components being modified.".to_string(),
75 },
76 Message {
77 role: "user".to_string(),
78 content: format!(
79 "Generate a conventional commit message for the following git diff.\n\
80 The message must follow the conventional commit format.\n\
81 If the diff is large (>5 files or >100 lines), make the description more comprehensive.\n\
82 Only return the commit message, nothing else.\n\n\
83 Diff:\n{}",
84 diff
85 ),
86 },
87 ],
88 config: self.default_config(),
89 };
90
91 self.generate(context).await
92 }
93
94 fn validate_message(&self, message: &str) -> bool {
95 let parts: Vec<&str> = message.splitn(2, ": ").collect();
97 if parts.len() != 2 {
98 return false;
99 }
100
101 let type_part = parts[0];
102 let commit_type = if type_part.contains('(') {
103 type_part.split('(').next().unwrap_or("")
104 } else {
105 type_part
106 };
107
108 let valid_types = [
109 "feat", "fix", "docs", "style", "refactor",
110 "perf", "test", "build", "ci", "chore", "revert"
111 ];
112
113 valid_types.contains(&commit_type)
114 }
115}
116
117pub struct TogetherAiProvider {
119 api_key: String,
120 model: String,
121 client: reqwest::Client,
122}
123
124#[async_trait]
125impl ModelProvider for TogetherAiProvider {
126 type Error = crate::Error;
127
128 async fn generate(&self, context: ModelContext) -> Result<String, Self::Error> {
129 let request = serde_json::json!({
130 "model": self.model,
131 "messages": context.messages,
132 "temperature": context.config.temperature,
133 "max_tokens": context.config.max_tokens,
134 "stop": context.config.stop_sequences,
135 });
136
137 let response = self.client
138 .post("https://api.together.xyz/v1/chat/completions")
139 .header("Authorization", format!("Bearer {}", self.api_key))
140 .json(&request)
141 .send()
142 .await?
143 .error_for_status()?
144 .json::<serde_json::Value>()
145 .await?;
146
147 response["choices"][0]["message"]["content"]
148 .as_str()
149 .map(|s| s.trim().to_string())
150 .ok_or_else(|| crate::Error::CommitMessageGeneration("No response from API".to_string()))
151 }
152
153 fn model_id(&self) -> &str {
154 &self.model
155 }
156
157 fn default_config(&self) -> GenerationConfig {
158 GenerationConfig {
159 temperature: 0.3,
160 max_tokens: 100,
161 stop_sequences: vec!["\n".to_string()],
162 }
163 }
164}
165
166impl TogetherAiProvider {
167 pub fn new(api_key: String, model: String) -> Self {
168 Self {
169 api_key,
170 model,
171 client: reqwest::Client::new(),
172 }
173 }
174}