1use crate::agents::ActionDisplay;
2use rustc_hash::FxHashMap;
3use serde::{Deserialize, Serialize};
4use std::sync::Arc;
5
/// A single message in a chat conversation.
///
/// Fields marked `#[serde(default)]` may be missing from serialized data; they then
/// deserialize to an empty `Vec` / `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who authored the message.
    pub role: MessageRole,
    /// Message text.
    pub content: String,
    /// Timestamp of the message, in the local time zone.
    pub timestamp: chrono::DateTime<chrono::Local>,
    /// Actions associated with this message, as `ActionDisplay` values.
    #[serde(default)]
    pub actions: Vec<ActionDisplay>,
    /// Reasoning text kept separate from `content`, if any.
    /// NOTE(review): presumably produced by [`ChatMessage::extract_thinking`] — confirm against callers.
    #[serde(default)]
    pub thinking: Option<String>,
    /// Images attached to the message, if any.
    /// NOTE(review): whether each string is a path, URL, or encoded data is not visible here.
    #[serde(default)]
    pub images: Option<Vec<String>>,
    /// Tool calls attached to the message, if any.
    #[serde(default)]
    pub tool_calls: Option<Vec<crate::models::tool_call::ToolCall>>,
}
25
26impl ChatMessage {
27 pub fn extract_thinking(text: &str) -> (Option<String>, String) {
30 if !text.contains("Thinking...") {
32 return (None, text.to_string());
33 }
34
35 if let Some(thinking_start) = text.find("Thinking...") {
37 if let Some(thinking_end) = text.find("...done thinking.") {
38 let thinking_content_start = thinking_start + "Thinking...".len();
40 let thinking_text = text[thinking_content_start..thinking_end].trim().to_string();
41
42 let answer_start = thinking_end + "...done thinking.".len();
44 let answer_text = text[answer_start..].trim().to_string();
45
46 return (Some(thinking_text), answer_text);
47 }
48 }
49
50 if let Some(thinking_start) = text.find("Thinking...") {
52 let thinking_content_start = thinking_start + "Thinking...".len();
53 let thinking_text = text[thinking_content_start..].trim().to_string();
54 return (Some(thinking_text), String::new());
55 }
56
57 (None, text.to_string())
58 }
59}
60
61#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
62pub enum MessageRole {
63 User,
64 Assistant,
65 System,
66}
67
/// Project files and metadata used to build textual prompt context
/// (rendered by `to_prompt_context`).
#[derive(Debug, Clone)]
pub struct ProjectContext {
    /// Project root path; rendered verbatim in the prompt.
    pub root_path: String,
    /// Known files keyed by path, with their contents.
    pub files: FxHashMap<String, String>,
    /// Optional project type label (e.g. `"rust"`), rendered only when set.
    pub project_type: Option<String>,
    /// Token count for this context.
    /// NOTE(review): no method visible here updates it; callers appear responsible — confirm.
    pub token_count: usize,
    /// Paths (keys into `files`) whose contents are embedded in the prompt, in this order;
    /// paths without an entry in `files` are skipped when rendering.
    pub included_files: Vec<String>,
}
82
83impl ProjectContext {
84 pub fn new(root_path: String) -> Self {
85 Self {
86 root_path,
87 files: FxHashMap::default(),
88 project_type: None,
89 token_count: 0,
90 included_files: Vec::new(),
91 }
92 }
93
94 pub fn add_file(&mut self, path: String, content: String) {
96 self.files.insert(path, content);
97 }
98
99 pub fn to_prompt_context(&self) -> String {
101 let header_size = 100; let file_list_size = self.files.keys().map(|k| k.len() + 5).sum::<usize>(); let content_size: usize = self.included_files.iter()
105 .filter_map(|path| self.files.get(path).map(|content| (path, content)))
106 .map(|(path, content)| content.len() + path.len() + 20) .sum();
108
109 let capacity = header_size + file_list_size + content_size;
110 let mut context = String::with_capacity(capacity);
111
112 if let Some(project_type) = &self.project_type {
113 context.push_str("Project type: ");
114 context.push_str(project_type);
115 context.push('\n');
116 }
117
118 context.push_str("Project root: ");
119 context.push_str(&self.root_path);
120 context.push_str("\nFiles in context: ");
121 context.push_str(&self.files.len().to_string());
122 context.push_str("\n\n");
123
124 context.push_str("Project structure:\n");
126 for path in self.files.keys() {
127 context.push_str(" - ");
128 context.push_str(path);
129 context.push('\n');
130 }
131 context.push('\n');
132
133 if !self.included_files.is_empty() {
135 context.push_str("Relevant file contents:\n");
136 for file_path in &self.included_files {
137 if let Some(content) = self.files.get(file_path) {
138 context.push_str("\n=== ");
139 context.push_str(file_path);
140 context.push_str(" ===\n");
141 context.push_str(content);
142 context.push_str("\n=== end ===\n");
143 }
144 }
145 }
146
147 context
148 }
149}
150
/// Result of a model call: the generated text plus optional metadata.
#[derive(Debug, Clone)]
pub struct ModelResponse {
    /// Response text.
    pub content: String,
    /// Token usage for the call, when available.
    pub usage: Option<TokenUsage>,
    /// Identifier of the model that produced the response (e.g. `"ollama/tinyllama"`).
    pub model_name: String,
    /// Reasoning text kept separate from `content`, if any.
    pub thinking: Option<String>,
    /// Tool calls carried by the response, if any.
    pub tool_calls: Option<Vec<crate::models::tool_call::ToolCall>>,
}
165
/// Token accounting attached to a model response.
///
/// Derives `Default` (all counts zero) and `PartialEq`/`Eq` so usage values can be
/// initialised and compared directly.
///
/// NOTE(review): `total_tokens` is stored as given rather than derived; it is presumably
/// `prompt_tokens + completion_tokens`, but nothing here enforces that.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TokenUsage {
    /// Tokens counted for the prompt (input).
    pub prompt_tokens: usize,
    /// Tokens counted for the completion (output).
    pub completion_tokens: usize,
    /// Total tokens for the call.
    pub total_tokens: usize,
}
173
/// Shared callback that receives text as it is streamed.
///
/// Stored behind an `Arc` and bounded by `Send + Sync`, so it can be cloned cheaply and
/// shared across threads.
/// NOTE(review): presumably invoked once per streamed chunk — confirm against callers.
pub type StreamCallback = Arc<dyn Fn(&str) + Send + Sync>;
176
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_role_equality() {
        let user1 = MessageRole::User;
        let user2 = MessageRole::User;
        let assistant = MessageRole::Assistant;

        assert_eq!(user1, user2, "User roles should be equal");
        assert_ne!(user1, assistant, "Different roles should not be equal");
    }

    #[test]
    fn test_chat_message_creation() {
        let message = ChatMessage {
            role: MessageRole::User,
            content: "Hello, assistant!".to_string(),
            timestamp: chrono::Local::now(),
            actions: vec![],
            thinking: None,
            images: None,
            tool_calls: None,
        };

        assert_eq!(message.role, MessageRole::User);
        assert_eq!(message.content, "Hello, assistant!");
        assert!(message.actions.is_empty());
        assert!(message.thinking.is_none());
        assert!(message.images.is_none());
        assert!(message.tool_calls.is_none());
    }

    #[test]
    fn test_extract_thinking_splits_sections() {
        let (thinking, answer) =
            ChatMessage::extract_thinking("Thinking...\nreasoning\n...done thinking.\n\nAnswer");
        assert_eq!(thinking.as_deref(), Some("reasoning"));
        assert_eq!(answer, "Answer");
    }

    #[test]
    fn test_extract_thinking_without_markers() {
        let (thinking, answer) = ChatMessage::extract_thinking("Just an answer");
        assert!(thinking.is_none());
        assert_eq!(answer, "Just an answer");
    }

    #[test]
    fn test_extract_thinking_unterminated() {
        let (thinking, answer) = ChatMessage::extract_thinking("Thinking... partial");
        assert_eq!(thinking.as_deref(), Some("partial"));
        assert!(answer.is_empty());
    }

    #[test]
    fn test_project_context_new() {
        let context = ProjectContext::new("/home/user/project".to_string());

        assert_eq!(context.root_path, "/home/user/project");
        assert!(context.files.is_empty());
        assert_eq!(context.token_count, 0);
        assert!(context.included_files.is_empty());
        assert_eq!(context.project_type, None);
    }

    #[test]
    fn test_project_context_add_file() {
        let mut context = ProjectContext::new("/project".to_string());

        context.add_file("src/main.rs".to_string(), "fn main() {}".to_string());
        context.add_file("Cargo.toml".to_string(), "[package]".to_string());

        assert_eq!(context.files.len(), 2);
        assert_eq!(
            context.files.get("src/main.rs"),
            Some(&"fn main() {}".to_string())
        );
        assert_eq!(
            context.files.get("Cargo.toml"),
            Some(&"[package]".to_string())
        );
    }

    #[test]
    fn test_project_context_prompt_formatting() {
        let mut context = ProjectContext::new("/project".to_string());
        context.project_type = Some("rust".to_string());
        context.add_file("src/main.rs".to_string(), "fn main() {}".to_string());
        context.add_file("Cargo.toml".to_string(), "[package]".to_string());
        context.included_files = vec!["src/main.rs".to_string()];

        let prompt = context.to_prompt_context();

        assert!(
            prompt.contains("Project type: rust"),
            "Should include project type"
        );
        assert!(
            prompt.contains("Project root: /project"),
            "Should include project root"
        );
        assert!(
            prompt.contains("Files in context: 2"),
            "Should include file count"
        );
        assert!(
            prompt.contains("src/main.rs"),
            "Should include file structure"
        );
        assert!(
            prompt.contains("Cargo.toml"),
            "Should include file structure"
        );
        assert!(
            prompt.contains("fn main() {}"),
            "Should include file content"
        );
        assert!(
            prompt.contains("Relevant file contents:"),
            "Should include section for relevant files"
        );
        assert!(
            prompt.contains("=== src/main.rs ==="),
            "Should delimit included file contents"
        );
    }

    #[test]
    fn test_token_usage_structure() {
        let usage = TokenUsage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
        };

        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_model_response_creation() {
        let usage = TokenUsage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
        };

        // Every field must be listed: the struct has no Default impl.
        let response = ModelResponse {
            content: "Hello, world!".to_string(),
            usage: Some(usage),
            model_name: "ollama/tinyllama".to_string(),
            thinking: None,
            tool_calls: None,
        };

        assert_eq!(response.content, "Hello, world!");
        assert!(response.usage.is_some());
        assert_eq!(response.model_name, "ollama/tinyllama");
        assert!(response.thinking.is_none());
        assert!(response.tool_calls.is_none());
        assert_eq!(response.usage.unwrap().total_tokens, 150);
    }
}