//! mermaid_cli/models/types.rs
//!
//! Core data types shared across the CLI: chat messages and roles,
//! project context assembly, model responses, and token usage stats.
use std::sync::Arc;

use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};

use crate::agents::ActionDisplay;

6/// Represents a chat message
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct ChatMessage {
9    pub role: MessageRole,
10    pub content: String,
11    pub timestamp: chrono::DateTime<chrono::Local>,
12    /// Actions performed during this message (for display purposes)
13    #[serde(default)]
14    pub actions: Vec<ActionDisplay>,
15    /// Thinking/reasoning content (for models that expose their thought process)
16    #[serde(default)]
17    pub thinking: Option<String>,
18    /// Base64-encoded images/PDFs for multimodal models
19    #[serde(default)]
20    pub images: Option<Vec<String>>,
21    /// Tool calls from the model (Ollama native function calling)
22    #[serde(default)]
23    pub tool_calls: Option<Vec<crate::models::tool_call::ToolCall>>,
24    /// Tool call ID for tool result messages (OpenAI-compatible format)
25    /// This links the tool result back to the original tool_call from the assistant
26    #[serde(default)]
27    pub tool_call_id: Option<String>,
28    /// Tool name for tool result messages (required by Ollama API)
29    /// This tells the model which function's result is being returned
30    #[serde(default)]
31    pub tool_name: Option<String>,
32}
33
34impl ChatMessage {
35    /// Extract thinking blocks from message content
36    /// Returns (thinking_content, answer_content)
37    pub fn extract_thinking(text: &str) -> (Option<String>, String) {
38        // Check if the text contains thinking blocks
39        if !text.contains("Thinking...") {
40            return (None, text.to_string());
41        }
42
43        // Find thinking block boundaries
44        if let Some(thinking_start) = text.find("Thinking...") {
45            if let Some(thinking_end) = text.find("...done thinking.") {
46                // Extract thinking content (everything between markers)
47                let thinking_content_start = thinking_start + "Thinking...".len();
48                let thinking_text = text[thinking_content_start..thinking_end].trim().to_string();
49
50                // Extract answer (everything after thinking block)
51                let answer_start = thinking_end + "...done thinking.".len();
52                let answer_text = text[answer_start..].trim().to_string();
53
54                return (Some(thinking_text), answer_text);
55            }
56        }
57
58        // If we found "Thinking..." but not the end marker, treat it all as thinking in progress
59        if let Some(thinking_start) = text.find("Thinking...") {
60            let thinking_content_start = thinking_start + "Thinking...".len();
61            let thinking_text = text[thinking_content_start..].trim().to_string();
62            return (Some(thinking_text), String::new());
63        }
64
65        (None, text.to_string())
66    }
67}
68
69#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
70pub enum MessageRole {
71    User,
72    Assistant,
73    System,
74    /// Tool result message (OpenAI-compatible format for function calling)
75    Tool,
76}
77
78/// Represents the context of the current project
79#[derive(Debug, Clone)]
80pub struct ProjectContext {
81    /// Root directory of the project
82    pub root_path: String,
83    /// Map of file paths to their contents (using FxHashMap for performance)
84    pub files: FxHashMap<String, String>,
85    /// Total token count of the context
86    pub token_count: usize,
87    /// Files to explicitly include in context
88    pub included_files: Vec<String>,
89}
90
91impl ProjectContext {
92    pub fn new(root_path: String) -> Self {
93        Self {
94            root_path,
95            files: FxHashMap::default(),
96            token_count: 0,
97            included_files: Vec::new(),
98        }
99    }
100
101    /// Add a file to the context
102    pub fn add_file(&mut self, path: String, content: String) {
103        self.files.insert(path, content);
104    }
105
106    /// Get a formatted string of the project context for the model
107    pub fn to_prompt_context(&self) -> String {
108        // Pre-calculate capacity to reduce allocations
109        let header_size = 100; // Approximate size of headers
110        let file_list_size = self.files.keys().map(|k| k.len() + 5).sum::<usize>(); // "  - path\n"
111        let content_size: usize = self.included_files.iter()
112            .filter_map(|path| self.files.get(path).map(|content| (path, content)))
113            .map(|(path, content)| content.len() + path.len() + 20) // path + decorators
114            .sum();
115
116        let capacity = header_size + file_list_size + content_size;
117        let mut context = String::with_capacity(capacity);
118
119        context.push_str("Project root: ");
120        context.push_str(&self.root_path);
121        context.push_str("\nFiles in context: ");
122        context.push_str(&self.files.len().to_string());
123        context.push_str("\n\n");
124
125        // Add file tree structure
126        context.push_str("Project structure:\n");
127        for path in self.files.keys() {
128            context.push_str("  - ");
129            context.push_str(path);
130            context.push('\n');
131        }
132        context.push('\n');
133
134        // Add explicitly included files
135        if !self.included_files.is_empty() {
136            context.push_str("Relevant file contents:\n");
137            for file_path in &self.included_files {
138                if let Some(content) = self.files.get(file_path) {
139                    context.push_str("\n=== ");
140                    context.push_str(file_path);
141                    context.push_str(" ===\n");
142                    context.push_str(content);
143                    context.push_str("\n=== end ===\n");
144                }
145            }
146        }
147
148        context
149    }
150}
151
152/// Response from a model
153#[derive(Debug, Clone)]
154pub struct ModelResponse {
155    /// The actual response text
156    pub content: String,
157    /// Usage statistics if available
158    pub usage: Option<TokenUsage>,
159    /// Model that generated the response
160    pub model_name: String,
161    /// Thinking/reasoning content (for models that expose their thought process)
162    pub thinking: Option<String>,
163    /// Tool calls from the model (Ollama native function calling)
164    pub tool_calls: Option<Vec<crate::models::tool_call::ToolCall>>,
165}
166
/// Token usage statistics reported by a model backend.
///
/// Plain value type of three `usize`s, so `Copy`, `PartialEq`, and `Eq`
/// are derived: usage numbers can be passed around and compared without
/// cloning. (Purely additive; existing `Clone` callers are unaffected.)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt/context.
    pub prompt_tokens: usize,
    /// Tokens generated in the completion.
    pub completion_tokens: usize,
    /// Total tokens as reported by the backend (typically prompt + completion).
    pub total_tokens: usize,
}
174
/// Stream callback type for real-time response streaming.
/// Invoked with each text chunk as it arrives; wrapped in `Arc` with
/// `Send + Sync` bounds so one callback can be shared across threads/tasks.
pub type StreamCallback = Arc<dyn Fn(&str) + Send + Sync>;
177
#[cfg(test)]
mod tests {
    use super::*;

    // Phase 3 Test Suite: Model Types — seven tests covering roles,
    // messages, project context, and response/usage structs.

    #[test]
    fn test_message_role_equality() {
        assert_eq!(
            MessageRole::User,
            MessageRole::User,
            "User roles should be equal"
        );
        assert_ne!(
            MessageRole::User,
            MessageRole::Assistant,
            "Different roles should not be equal"
        );
    }

    #[test]
    fn test_chat_message_creation() {
        let msg = ChatMessage {
            role: MessageRole::User,
            content: String::from("Hello, assistant!"),
            timestamp: chrono::Local::now(),
            actions: Vec::new(),
            thinking: None,
            images: None,
            tool_calls: None,
            tool_call_id: None,
            tool_name: None,
        };

        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.content, "Hello, assistant!");
        // All optional/extension fields start out empty.
        assert!(msg.actions.is_empty());
        assert!(
            msg.thinking.is_none()
                && msg.images.is_none()
                && msg.tool_calls.is_none()
                && msg.tool_call_id.is_none()
                && msg.tool_name.is_none()
        );
    }

    #[test]
    fn test_project_context_new() {
        let ctx = ProjectContext::new(String::from("/home/user/project"));

        assert_eq!(ctx.root_path, "/home/user/project");
        assert_eq!(ctx.token_count, 0);
        assert!(ctx.files.is_empty());
        assert!(ctx.included_files.is_empty());
    }

    #[test]
    fn test_project_context_add_file() {
        let mut ctx = ProjectContext::new(String::from("/project"));

        for (path, body) in [("src/main.rs", "fn main() {}"), ("Cargo.toml", "[package]")] {
            ctx.add_file(path.to_string(), body.to_string());
        }

        assert_eq!(ctx.files.len(), 2);
        assert_eq!(
            ctx.files.get("src/main.rs").map(String::as_str),
            Some("fn main() {}")
        );
        assert_eq!(
            ctx.files.get("Cargo.toml").map(String::as_str),
            Some("[package]")
        );
    }

    #[test]
    fn test_project_context_prompt_formatting() {
        let mut ctx = ProjectContext::new(String::from("/project"));
        ctx.add_file("src/main.rs".to_string(), "fn main() {}".to_string());
        ctx.add_file("Cargo.toml".to_string(), "[package]".to_string());
        ctx.included_files = vec!["src/main.rs".to_string()];

        let prompt = ctx.to_prompt_context();

        // Substring checks only: FxHashMap iteration order is unspecified.
        for (needle, why) in [
            ("Project root: /project", "Should include project root"),
            ("Files in context: 2", "Should include file count"),
            ("src/main.rs", "Should include file structure"),
            ("Cargo.toml", "Should include file structure"),
            ("fn main() {}", "Should include file content"),
        ] {
            assert!(prompt.contains(needle), "{}", why);
        }
        // Check that included files section exists
        assert!(
            prompt.contains("Relevant file contents") || prompt.contains("==="),
            "Should include section for relevant files"
        );
    }

    #[test]
    fn test_token_usage_structure() {
        let usage = TokenUsage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
        };

        assert_eq!(
            (usage.prompt_tokens, usage.completion_tokens, usage.total_tokens),
            (100, 50, 150)
        );
    }

    #[test]
    fn test_model_response_creation() {
        let response = ModelResponse {
            content: "Hello, world!".to_string(),
            usage: Some(TokenUsage {
                prompt_tokens: 100,
                completion_tokens: 50,
                total_tokens: 150,
            }),
            model_name: "ollama/tinyllama".to_string(),
            thinking: None,
            tool_calls: None,
        };

        assert_eq!(response.content, "Hello, world!");
        assert_eq!(response.model_name, "ollama/tinyllama");
        assert!(response.tool_calls.is_none());
        let usage = response.usage.expect("usage should be present");
        assert_eq!(usage.total_tokens, 150);
    }
}