Skip to main content

mermaid_cli/runtime/
non_interactive.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4use tokio::sync::RwLock;
5
6use crate::utils::MutexExt;
7
8use crate::{
9    agents::{execute_action, ActionResult as AgentActionResult, AgentAction},
10    app::Config,
11    cli::OutputFormat,
12    constants::{DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE},
13    models::{ChatMessage, Model, ModelConfig, ModelFactory},
14    prompts,
15};
16
17/// Result of a non-interactive run
18#[derive(Debug, Serialize, Deserialize)]
19pub struct NonInteractiveResult {
20    /// The prompt that was executed
21    pub prompt: String,
22    /// The model's response
23    pub response: String,
24    /// Actions that were executed (if any)
25    pub actions: Vec<ActionResult>,
26    /// Any errors that occurred
27    pub errors: Vec<String>,
28    /// Metadata about the execution
29    pub metadata: ExecutionMetadata,
30}
31
32#[derive(Debug, Serialize, Deserialize)]
33pub struct ActionResult {
34    /// Type of action (file_write, command, etc.)
35    pub action_type: String,
36    /// Target (file path or command)
37    pub target: String,
38    /// Whether the action was executed successfully
39    pub success: bool,
40    /// Output or error message
41    pub output: Option<String>,
42}
43
44#[derive(Debug, Serialize, Deserialize)]
45pub struct ExecutionMetadata {
46    /// Model used
47    pub model: String,
48    /// Total tokens used
49    pub tokens_used: Option<usize>,
50    /// Execution time in milliseconds
51    pub duration_ms: u128,
52    /// Whether actions were executed
53    pub actions_executed: bool,
54}
55
56/// Non-interactive runner for executing single prompts
57pub struct NonInteractiveRunner {
58    model: Arc<RwLock<Box<dyn Model>>>,
59    no_execute: bool,
60    max_tokens: Option<usize>,
61}
62
63impl NonInteractiveRunner {
64    /// Create a new non-interactive runner
65    pub async fn new(
66        model_id: String,
67        config: Config,
68        no_execute: bool,
69        max_tokens: Option<usize>,
70        backend: Option<&str>,
71    ) -> Result<Self> {
72        // Create model instance with optional backend preference
73        let model = ModelFactory::create_with_backend(&model_id, Some(&config), backend).await?;
74
75        Ok(Self {
76            model: Arc::new(RwLock::new(model)),
77            no_execute,
78            max_tokens,
79        })
80    }
81
82    /// Execute a single prompt and return the result
83    pub async fn execute(&self, prompt: String) -> Result<NonInteractiveResult> {
84        let start_time = std::time::Instant::now();
85        let mut errors = Vec::new();
86        let mut actions = Vec::new();
87
88        // Build messages using the same system prompt as interactive mode
89        let system_message = ChatMessage::system(prompts::get_system_prompt());
90        let user_message = ChatMessage::user(prompt.clone());
91
92        let messages = vec![system_message, user_message];
93
94        // Get model name from the model
95        let model_guard = self.model.read().await;
96        let model_name = model_guard.name().to_string();
97        drop(model_guard);
98
99        // Create model config
100        let model_config = ModelConfig {
101            model: model_name,
102            temperature: DEFAULT_TEMPERATURE,
103            max_tokens: self.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
104            top_p: Some(1.0),
105            frequency_penalty: None,
106            presence_penalty: None,
107            system_prompt: None,
108            thinking_enabled: false, // Non-interactive mode doesn't need thinking
109            backend_options: std::collections::HashMap::new(),
110        };
111
112        // Send prompt to model
113        let full_response;
114        let tokens_used;
115
116        // Create a callback to capture the response
117        let response_text = Arc::new(std::sync::Mutex::new(String::new()));
118        let response_clone = Arc::clone(&response_text);
119        let callback = Arc::new(move |chunk: &str| {
120            let mut resp = response_clone.lock_mut_safe();
121            resp.push_str(chunk);
122        });
123
124        // Call the model
125        let model_name;
126        let result = {
127            let model = self.model.write().await;
128            model_name = model.name().to_string();
129            model
130                .chat(&messages, &model_config, Some(callback))
131                .await
132        };
133
134        // Parse actions from tool calls (Ollama native function calling)
135        let parsed_actions: Vec<AgentAction> = match result {
136            Ok(response) => {
137                // Try to get content from the callback first
138                let callback_content = response_text.lock_mut_safe().clone();
139                if !callback_content.is_empty() {
140                    full_response = callback_content;
141                } else {
142                    full_response = response.content;
143                }
144                tokens_used = response.usage.map(|u| u.total_tokens).unwrap_or(0);
145
146                // Convert tool_calls to AgentActions
147                if let Some(tool_calls) = response.tool_calls {
148                    tool_calls
149                        .iter()
150                        .filter_map(|tc| tc.to_agent_action().ok())
151                        .collect()
152                } else {
153                    vec![]
154                }
155            },
156            Err(e) => {
157                errors.push(format!("Model error: {}", e));
158                full_response = response_text.lock_mut_safe().clone();
159                tokens_used = 0;
160                vec![]
161            },
162        };
163
164        // Process parsed actions
165        for action in parsed_actions {
166            let (action_type, target) = extract_action_info(&action);
167
168            if self.no_execute {
169                actions.push(ActionResult {
170                    action_type,
171                    target,
172                    success: false,
173                    output: Some("Not executed (--no-execute mode)".to_string()),
174                });
175            } else {
176                let result = execute_action(&action).await;
177                let action_result = match result {
178                    AgentActionResult::Success { output } => ActionResult {
179                        action_type,
180                        target,
181                        success: true,
182                        output: Some(output),
183                    },
184                    AgentActionResult::Error { error } => ActionResult {
185                        action_type,
186                        target,
187                        success: false,
188                        output: Some(error),
189                    },
190                };
191                actions.push(action_result);
192            }
193        }
194
195        let duration_ms = start_time.elapsed().as_millis();
196        let actions_executed = !self.no_execute && !actions.is_empty();
197
198        Ok(NonInteractiveResult {
199            prompt,
200            response: full_response,
201            actions,
202            errors,
203            metadata: ExecutionMetadata {
204                model: model_name,
205                tokens_used: Some(tokens_used),
206                duration_ms,
207                actions_executed,
208            },
209        })
210    }
211
212    /// Format the result according to the output format
213    pub fn format_result(&self, result: &NonInteractiveResult, format: OutputFormat) -> String {
214        match format {
215            OutputFormat::Json => serde_json::to_string_pretty(result).unwrap_or_else(|e| {
216                format!("{{\"error\": \"Failed to serialize result: {}\"}}", e)
217            }),
218            OutputFormat::Text => {
219                let mut output = String::new();
220                output.push_str(&result.response);
221
222                if !result.actions.is_empty() {
223                    output.push_str("\n\n--- Actions ---\n");
224                    for action in &result.actions {
225                        output.push_str(&format!(
226                            "[{}] {} - {}\n",
227                            if action.success { "OK" } else { "FAIL" },
228                            action.action_type,
229                            action.target
230                        ));
231                        if let Some(ref out) = action.output {
232                            output.push_str(&format!("  {}\n", out));
233                        }
234                    }
235                }
236
237                if !result.errors.is_empty() {
238                    output.push_str("\n--- Errors ---\n");
239                    for error in &result.errors {
240                        output.push_str(&format!("• {}\n", error));
241                    }
242                }
243
244                output
245            },
246            OutputFormat::Markdown => {
247                let mut output = String::new();
248
249                output.push_str("## Response\n\n");
250                output.push_str(&result.response);
251                output.push_str("\n\n");
252
253                if !result.actions.is_empty() {
254                    output.push_str("## Actions Executed\n\n");
255                    for action in &result.actions {
256                        let status = if action.success { "SUCCESS" } else { "FAILED" };
257                        output.push_str(&format!(
258                            "- {} **{}**: `{}`\n",
259                            status, action.action_type, action.target
260                        ));
261                        if let Some(ref out) = action.output {
262                            output.push_str(&format!("  ```\n  {}\n  ```\n", out));
263                        }
264                    }
265                    output.push('\n');
266                }
267
268                if !result.errors.is_empty() {
269                    output.push_str("## Errors\n\n");
270                    for error in &result.errors {
271                        output.push_str(&format!("- {}\n", error));
272                    }
273                    output.push('\n');
274                }
275
276                output.push_str("---\n");
277                output.push_str(&format!(
278                    "*Model: {} | Tokens: {} | Duration: {}ms*\n",
279                    result.metadata.model,
280                    result.metadata.tokens_used.unwrap_or(0),
281                    result.metadata.duration_ms
282                ));
283
284                output
285            },
286        }
287    }
288}
289
290/// Extract action type and target description from an AgentAction
291fn extract_action_info(action: &AgentAction) -> (String, String) {
292    match action {
293        AgentAction::WriteFile { path, .. } => ("file_write".to_string(), path.clone()),
294        AgentAction::EditFile { path, .. } => ("edit_file".to_string(), path.clone()),
295        AgentAction::ExecuteCommand { command, .. } => ("command".to_string(), command.clone()),
296        AgentAction::ReadFile { paths } => {
297            if paths.len() == 1 {
298                ("file_read".to_string(), paths[0].clone())
299            } else {
300                ("file_read".to_string(), format!("{} files", paths.len()))
301            }
302        }
303        AgentAction::CreateDirectory { path } => ("create_dir".to_string(), path.clone()),
304        AgentAction::DeleteFile { path } => ("delete_file".to_string(), path.clone()),
305        AgentAction::GitDiff { paths } => {
306            if paths.len() == 1 {
307                ("git_diff".to_string(), paths[0].as_deref().unwrap_or("*").to_string())
308            } else {
309                ("git_diff".to_string(), format!("{} paths", paths.len()))
310            }
311        }
312        AgentAction::GitStatus => ("git_status".to_string(), "git status".to_string()),
313        AgentAction::GitCommit { message, .. } => ("git_commit".to_string(), message.clone()),
314        AgentAction::WebSearch { queries } => {
315            if queries.len() == 1 {
316                ("web_search".to_string(), queries[0].0.clone())
317            } else {
318                ("web_search".to_string(), format!("{} queries", queries.len()))
319            }
320        }
321        AgentAction::WebFetch { url } => ("web_fetch".to_string(), url.clone()),
322    }
323}