1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::path::PathBuf;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6
7use crate::utils::MutexExt;
8
9use crate::{
10 agents::{execute_action, ActionResult as AgentActionResult, AgentAction},
11 app::Config,
12 cli::OutputFormat,
13 context::ContextLoader,
14 models::{ChatMessage, MessageRole, Model, ModelConfig, ModelFactory, ProjectContext, parse_tool_calls, group_parallel_reads},
15};
16
17#[derive(Debug, Serialize, Deserialize)]
19pub struct NonInteractiveResult {
20 pub prompt: String,
22 pub response: String,
24 pub actions: Vec<ActionResult>,
26 pub errors: Vec<String>,
28 pub metadata: ExecutionMetadata,
30}
31
32#[derive(Debug, Serialize, Deserialize)]
33pub struct ActionResult {
34 pub action_type: String,
36 pub target: String,
38 pub success: bool,
40 pub output: Option<String>,
42}
43
44#[derive(Debug, Serialize, Deserialize)]
45pub struct ExecutionMetadata {
46 pub model: String,
48 pub tokens_used: Option<usize>,
50 pub duration_ms: u128,
52 pub actions_executed: bool,
54}
55
56pub struct NonInteractiveRunner {
58 model: Arc<RwLock<Box<dyn Model>>>,
59 context: ProjectContext,
60 no_execute: bool,
61 max_tokens: Option<usize>,
62}
63
64impl NonInteractiveRunner {
65 pub async fn new(
67 model_id: String,
68 project_path: PathBuf,
69 config: Config,
70 no_execute: bool,
71 max_tokens: Option<usize>,
72 backend: Option<&str>,
73 ) -> Result<Self> {
74 let model = ModelFactory::create_with_backend(&model_id, Some(&config), backend).await?;
76
77 let loader = ContextLoader::new()?;
79 let context = loader.load_context(&project_path).await?;
80
81 Ok(Self {
82 model: Arc::new(RwLock::new(model)),
83 context,
84 no_execute,
85 max_tokens,
86 })
87 }
88
89 pub async fn execute(&self, prompt: String) -> Result<NonInteractiveResult> {
91 let start_time = std::time::Instant::now();
92 let mut errors = Vec::new();
93 let mut actions = Vec::new();
94
95 let prompt_lower = prompt.to_lowercase();
97 let is_code_related = prompt_lower.contains("code")
98 || prompt_lower.contains("file")
99 || prompt_lower.contains("function")
100 || prompt_lower.contains("class")
101 || prompt_lower.contains("implement")
102 || prompt_lower.contains("create")
103 || prompt_lower.contains("write")
104 || prompt_lower.contains("debug")
105 || prompt_lower.contains("fix")
106 || prompt_lower.contains("test")
107 || prompt_lower.contains("build")
108 || prompt_lower.contains("project")
109 || prompt_lower.contains("analyze")
110 || prompt_lower.contains("refactor");
111
112 let system_content = if is_code_related {
113 format!(
114 "You are an AI coding assistant. Here is the project structure:\n\n{}\n\n{}",
115 self.context.to_prompt_context(),
116 "You can use [FILE_WRITE: path] ... [/FILE_WRITE] blocks to write files, and [COMMAND: cmd] ... [/COMMAND] blocks to run commands."
117 )
118 } else {
119 "You are a helpful AI assistant. Answer the user's question directly and concisely."
120 .to_string()
121 };
122
123 let system_message = ChatMessage {
124 role: MessageRole::System,
125 content: system_content,
126 timestamp: chrono::Local::now(),
127 actions: Vec::new(),
128 thinking: None,
129 images: None,
130 tool_calls: None,
131 };
132
133 let user_message = ChatMessage {
134 role: MessageRole::User,
135 content: prompt.clone(),
136 timestamp: chrono::Local::now(),
137 actions: Vec::new(),
138 thinking: None,
139 images: None,
140 tool_calls: None,
141 };
142
143 let messages = vec![system_message, user_message];
144
145 let model_guard = self.model.read().await;
147 let model_name = model_guard.name().to_string();
148 drop(model_guard);
149
150 let model_config = ModelConfig {
152 model: model_name,
153 temperature: 0.7,
154 max_tokens: self.max_tokens.unwrap_or(4096),
155 top_p: Some(1.0),
156 frequency_penalty: None,
157 presence_penalty: None,
158 system_prompt: None,
159 backend_options: std::collections::HashMap::new(),
160 };
161
162 let full_response;
164 let tokens_used;
165
166 let response_text = Arc::new(std::sync::Mutex::new(String::new()));
168 let response_clone = Arc::clone(&response_text);
169 let callback = Arc::new(move |chunk: &str| {
170 let mut resp = response_clone.lock_mut_safe();
171 resp.push_str(chunk);
172 });
173
174 let model_name;
176 let result = {
177 let mut model = self.model.write().await;
178 model_name = model.name().to_string();
179 model
180 .chat(&messages, &self.context, &model_config, Some(callback))
181 .await
182 };
183
184 let parsed_actions = match result {
186 Ok(response) => {
187 let callback_content = response_text.lock_mut_safe().clone();
189 if !callback_content.is_empty() {
190 full_response = callback_content;
191 } else {
192 full_response = response.content;
193 }
194 tokens_used = response.usage.map(|u| u.total_tokens).unwrap_or(0);
195
196 if let Some(tool_calls) = response.tool_calls {
198 let parsed = parse_tool_calls(&tool_calls);
199 group_parallel_reads(parsed)
200 } else {
201 vec![]
202 }
203 },
204 Err(e) => {
205 errors.push(format!("Model error: {}", e));
206 full_response = response_text.lock_mut_safe().clone();
207 tokens_used = 0;
208 vec![]
209 },
210 };
211
212 if !self.no_execute && !parsed_actions.is_empty() {
214 for action in parsed_actions {
215 let (action_type, target) = match &action {
216 AgentAction::WriteFile { path, .. } => ("file_write", path.clone()),
217 AgentAction::ExecuteCommand { command, .. } => ("command", command.clone()),
218 AgentAction::ReadFile { path } => ("file_read", path.clone()),
219 AgentAction::CreateDirectory { path } => ("create_dir", path.clone()),
220 AgentAction::DeleteFile { path } => ("delete_file", path.clone()),
221 AgentAction::GitDiff { .. } => ("git_diff", "git diff".to_string()),
222 AgentAction::GitStatus => ("git_status", "git status".to_string()),
223 AgentAction::GitCommit { message, .. } => ("git_commit", message.clone()),
224 AgentAction::WebSearch { query, .. } => ("web_search", query.clone()),
225 AgentAction::ParallelRead { paths } => ("parallel_read", format!("{} files", paths.len())),
226 AgentAction::ParallelWebSearch { queries } => ("parallel_web_search", format!("{} queries", queries.len())),
227 AgentAction::ParallelGitDiff { paths } => ("parallel_git_diff", format!("{} paths", paths.len())),
228 };
229
230 let result = execute_action(&action)
231 .await
232 .unwrap_or(AgentActionResult::Error {
233 error: "Failed to execute action".to_string(),
234 });
235
236 let action_result = match result {
237 AgentActionResult::Success { output } => ActionResult {
238 action_type: action_type.to_string(),
239 target,
240 success: true,
241 output: Some(output),
242 },
243 AgentActionResult::Error { error } => ActionResult {
244 action_type: action_type.to_string(),
245 target,
246 success: false,
247 output: Some(error),
248 },
249 };
250
251 actions.push(action_result);
252 }
253 } else if !parsed_actions.is_empty() {
254 for action in parsed_actions {
256 let (action_type, target) = match action {
257 AgentAction::WriteFile { path, .. } => ("file_write", path),
258 AgentAction::ExecuteCommand { command, .. } => ("command", command),
259 AgentAction::ReadFile { path } => ("file_read", path),
260 AgentAction::CreateDirectory { path } => ("create_dir", path),
261 AgentAction::DeleteFile { path } => ("delete_file", path),
262 AgentAction::GitDiff { .. } => ("git_diff", "git diff".to_string()),
263 AgentAction::GitStatus => ("git_status", "git status".to_string()),
264 AgentAction::GitCommit { message, .. } => ("git_commit", message),
265 AgentAction::WebSearch { query, .. } => ("web_search", query),
266 AgentAction::ParallelRead { paths } => ("parallel_read", format!("{} files", paths.len())),
267 AgentAction::ParallelWebSearch { queries } => ("parallel_web_search", format!("{} queries", queries.len())),
268 AgentAction::ParallelGitDiff { paths } => ("parallel_git_diff", format!("{} paths", paths.len())),
269 };
270
271 actions.push(ActionResult {
272 action_type: action_type.to_string(),
273 target,
274 success: false,
275 output: Some("Not executed (--no-execute mode)".to_string()),
276 });
277 }
278 }
279
280 let duration_ms = start_time.elapsed().as_millis();
281 let actions_executed = !self.no_execute && !actions.is_empty();
282
283 Ok(NonInteractiveResult {
284 prompt,
285 response: full_response,
286 actions,
287 errors,
288 metadata: ExecutionMetadata {
289 model: model_name,
290 tokens_used: Some(tokens_used),
291 duration_ms,
292 actions_executed,
293 },
294 })
295 }
296
297 pub fn format_result(&self, result: &NonInteractiveResult, format: OutputFormat) -> String {
299 match format {
300 OutputFormat::Json => serde_json::to_string_pretty(result).unwrap_or_else(|e| {
301 format!("{{\"error\": \"Failed to serialize result: {}\"}}", e)
302 }),
303 OutputFormat::Text => {
304 let mut output = String::new();
305 output.push_str(&result.response);
306
307 if !result.actions.is_empty() {
308 output.push_str("\n\n--- Actions ---\n");
309 for action in &result.actions {
310 output.push_str(&format!(
311 "[{}] {} - {}\n",
312 if action.success { "OK" } else { "FAIL" },
313 action.action_type,
314 action.target
315 ));
316 if let Some(ref out) = action.output {
317 output.push_str(&format!(" {}\n", out));
318 }
319 }
320 }
321
322 if !result.errors.is_empty() {
323 output.push_str("\n--- Errors ---\n");
324 for error in &result.errors {
325 output.push_str(&format!("• {}\n", error));
326 }
327 }
328
329 output
330 },
331 OutputFormat::Markdown => {
332 let mut output = String::new();
333
334 output.push_str("## Response\n\n");
335 output.push_str(&result.response);
336 output.push_str("\n\n");
337
338 if !result.actions.is_empty() {
339 output.push_str("## Actions Executed\n\n");
340 for action in &result.actions {
341 let status = if action.success { "SUCCESS" } else { "FAILED" };
342 output.push_str(&format!(
343 "- {} **{}**: `{}`\n",
344 status, action.action_type, action.target
345 ));
346 if let Some(ref out) = action.output {
347 output.push_str(&format!(" ```\n {}\n ```\n", out));
348 }
349 }
350 output.push_str("\n");
351 }
352
353 if !result.errors.is_empty() {
354 output.push_str("## Errors\n\n");
355 for error in &result.errors {
356 output.push_str(&format!("- {}\n", error));
357 }
358 output.push_str("\n");
359 }
360
361 output.push_str("---\n");
362 output.push_str(&format!(
363 "*Model: {} | Tokens: {} | Duration: {}ms*\n",
364 result.metadata.model,
365 result.metadata.tokens_used.unwrap_or(0),
366 result.metadata.duration_ms
367 ));
368
369 output
370 },
371 }
372 }
373}