Skip to main content

dot/agent/
mod.rs

1mod events;
2mod profile;
3
4pub use events::AgentEvent;
5pub use profile::AgentProfile;
6
7use events::PendingToolCall;
8
9use crate::config::Config;
10use crate::db::Db;
11use crate::provider::{ContentBlock, Message, Provider, Role, StreamEventType, Usage};
12use crate::tools::ToolRegistry;
13use anyhow::Result;
14use tokio::sync::mpsc::UnboundedSender;
15
/// Token budget that compaction decisions are measured against.
const COMPACT_CONTEXT_LIMIT: u32 = 200_000;
/// Fraction of `COMPACT_CONTEXT_LIMIT` at which history compaction is triggered.
const COMPACT_THRESHOLD: f32 = 0.8;
/// Number of most recent messages kept verbatim when older history is summarized.
const COMPACT_KEEP_MESSAGES: usize = 10;
19
20pub struct Agent {
21    providers: Vec<Box<dyn Provider>>,
22    active: usize,
23    tools: ToolRegistry,
24    db: Db,
25    conversation_id: String,
26    messages: Vec<Message>,
27    profiles: Vec<AgentProfile>,
28    active_profile: usize,
29    pub thinking_budget: u32,
30    cwd: String,
31    agents_context: crate::context::AgentsContext,
32    last_input_tokens: u32,
33}
34
35impl Agent {
36    pub fn new(
37        providers: Vec<Box<dyn Provider>>,
38        db: Db,
39        _config: &Config,
40        tools: ToolRegistry,
41        profiles: Vec<AgentProfile>,
42        cwd: String,
43        agents_context: crate::context::AgentsContext,
44    ) -> Result<Self> {
45        assert!(!providers.is_empty(), "at least one provider required");
46        let conversation_id =
47            db.create_conversation(providers[0].model(), providers[0].name(), &cwd)?;
48        tracing::debug!("Agent created with conversation {}", conversation_id);
49        let profiles = if profiles.is_empty() {
50            vec![AgentProfile::default_profile()]
51        } else {
52            profiles
53        };
54        Ok(Agent {
55            providers,
56            active: 0,
57            tools,
58            db,
59            conversation_id,
60            messages: Vec::new(),
61            profiles,
62            active_profile: 0,
63            thinking_budget: 0,
64            cwd,
65            agents_context,
66            last_input_tokens: 0,
67        })
68    }
69    fn provider(&self) -> &dyn Provider {
70        &*self.providers[self.active]
71    }
72    fn provider_mut(&mut self) -> &mut dyn Provider {
73        &mut *self.providers[self.active]
74    }
75    fn profile(&self) -> &AgentProfile {
76        &self.profiles[self.active_profile]
77    }
78    pub fn conversation_id(&self) -> &str {
79        &self.conversation_id
80    }
81    pub fn messages(&self) -> &[Message] {
82        &self.messages
83    }
84    pub fn set_model(&mut self, model: String) {
85        self.provider_mut().set_model(model);
86    }
87    pub fn set_active_provider(&mut self, provider_name: &str, model: &str) {
88        if let Some(idx) = self
89            .providers
90            .iter()
91            .position(|p| p.name() == provider_name)
92        {
93            self.active = idx;
94            self.providers[idx].set_model(model.to_string());
95        }
96    }
97    pub fn set_thinking_budget(&mut self, budget: u32) {
98        self.thinking_budget = budget;
99    }
100    pub fn available_models(&self) -> Vec<String> {
101        self.provider().available_models()
102    }
103    pub async fn fetch_all_models(&self) -> Vec<(String, Vec<String>)> {
104        let mut result = Vec::new();
105        for p in &self.providers {
106            let models = match p.fetch_models().await {
107                Ok(m) => m,
108                Err(e) => {
109                    tracing::warn!("Failed to fetch models for {}: {e}", p.name());
110                    Vec::new()
111                }
112            };
113            result.push((p.name().to_string(), models));
114        }
115        result
116    }
117    pub fn current_model(&self) -> &str {
118        self.provider().model()
119    }
120    pub fn current_provider_name(&self) -> &str {
121        self.provider().name()
122    }
123    pub fn current_agent_name(&self) -> &str {
124        &self.profile().name
125    }
126    pub fn agent_profiles(&self) -> &[AgentProfile] {
127        &self.profiles
128    }
129    pub fn switch_agent(&mut self, name: &str) -> bool {
130        if let Some(idx) = self.profiles.iter().position(|p| p.name == name) {
131            self.active_profile = idx;
132            let model_spec = self.profiles[idx].model_spec.clone();
133
134            if let Some(spec) = model_spec {
135                let (provider, model) = Config::parse_model_spec(&spec);
136                if let Some(prov) = provider {
137                    self.set_active_provider(prov, model);
138                } else {
139                    self.set_model(model.to_string());
140                }
141            }
142            tracing::info!("Switched to agent '{}'", name);
143            true
144        } else {
145            false
146        }
147    }
148    pub fn new_conversation(&mut self) -> Result<()> {
149        if self.messages.is_empty() {
150            let _ = self.db.delete_conversation(&self.conversation_id);
151        }
152        let conversation_id = self.db.create_conversation(
153            self.provider().model(),
154            self.provider().name(),
155            &self.cwd,
156        )?;
157        self.conversation_id = conversation_id;
158        self.messages.clear();
159        Ok(())
160    }
161    pub fn resume_conversation(&mut self, conversation: &crate::db::Conversation) -> Result<()> {
162        self.conversation_id = conversation.id.clone();
163        self.messages = conversation
164            .messages
165            .iter()
166            .map(|m| Message {
167                role: if m.role == "user" {
168                    Role::User
169                } else {
170                    Role::Assistant
171                },
172                content: vec![ContentBlock::Text(m.content.clone())],
173            })
174            .collect();
175        tracing::debug!("Resumed conversation {}", conversation.id);
176        Ok(())
177    }
178    pub fn list_sessions(&self) -> Result<Vec<crate::db::ConversationSummary>> {
179        self.db.list_conversations_for_cwd(&self.cwd, 50)
180    }
181    pub fn get_session(&self, id: &str) -> Result<crate::db::Conversation> {
182        self.db.get_conversation(id)
183    }
184    pub fn conversation_title(&self) -> Option<String> {
185        self.db
186            .get_conversation(&self.conversation_id)
187            .ok()
188            .and_then(|c| c.title)
189    }
190    pub fn cwd(&self) -> &str {
191        &self.cwd
192    }
193    fn should_compact(&self) -> bool {
194        let threshold = (COMPACT_CONTEXT_LIMIT as f32 * COMPACT_THRESHOLD) as u32;
195        self.last_input_tokens >= threshold
196    }
197    async fn compact(&mut self, event_tx: &UnboundedSender<AgentEvent>) -> Result<()> {
198        let keep = COMPACT_KEEP_MESSAGES;
199        if self.messages.len() <= keep + 2 {
200            return Ok(());
201        }
202        let cutoff = self.messages.len() - keep;
203        let old_messages = self.messages[..cutoff].to_vec();
204        let kept = self.messages[cutoff..].to_vec();
205
206        let mut summary_text = String::new();
207        for msg in &old_messages {
208            let role = match msg.role {
209                Role::User => "User",
210                Role::Assistant => "Assistant",
211                Role::System => "System",
212            };
213            for block in &msg.content {
214                if let ContentBlock::Text(t) = block {
215                    summary_text.push_str(&format!("{}:\n{}\n\n", role, t));
216                }
217            }
218        }
219        let summary_request = vec![Message {
220            role: Role::User,
221            content: vec![ContentBlock::Text(format!(
222                "Summarize the following conversation history concisely, preserving all key decisions, facts, code changes, and context that would be needed to continue the work:\n\n{}",
223                summary_text
224            ))],
225        }];
226
227        let mut stream_rx = self
228            .provider()
229            .stream(
230                &summary_request,
231                Some("You are a concise summarizer. Produce a dense, factual summary."),
232                &[],
233                4096,
234                0,
235            )
236            .await?;
237        let mut full_summary = String::new();
238        while let Some(event) = stream_rx.recv().await {
239            if let StreamEventType::TextDelta(text) = event.event_type {
240                full_summary.push_str(&text);
241            }
242        }
243        self.messages = vec![
244            Message {
245                role: Role::User,
246                content: vec![ContentBlock::Text(
247                    "[Previous conversation summarized below]".to_string(),
248                )],
249            },
250            Message {
251                role: Role::Assistant,
252                content: vec![ContentBlock::Text(format!(
253                    "Summary of prior context:\n\n{}",
254                    full_summary
255                ))],
256            },
257        ];
258        self.messages.extend(kept);
259
260        let _ = self.db.add_message(
261            &self.conversation_id,
262            "assistant",
263            &format!("[Compacted {} messages into summary]", cutoff),
264        );
265        self.last_input_tokens = 0;
266        let _ = event_tx.send(AgentEvent::Compacted {
267            messages_removed: cutoff,
268        });
269        Ok(())
270    }
271    pub async fn send_message(
272        &mut self,
273        content: &str,
274        event_tx: UnboundedSender<AgentEvent>,
275    ) -> Result<()> {
276        self.send_message_with_images(content, Vec::new(), event_tx)
277            .await
278    }
279
280    pub async fn send_message_with_images(
281        &mut self,
282        content: &str,
283        images: Vec<(String, String)>,
284        event_tx: UnboundedSender<AgentEvent>,
285    ) -> Result<()> {
286        if self.should_compact() {
287            self.compact(&event_tx).await?;
288        }
289        self.db
290            .add_message(&self.conversation_id, "user", content)?;
291        let mut blocks: Vec<ContentBlock> = Vec::new();
292        for (media_type, data) in images {
293            blocks.push(ContentBlock::Image { media_type, data });
294        }
295        blocks.push(ContentBlock::Text(content.to_string()));
296        self.messages.push(Message {
297            role: Role::User,
298            content: blocks,
299        });
300        if self.messages.len() == 1 {
301            let title: String = content.chars().take(60).collect();
302            let _ = self
303                .db
304                .update_conversation_title(&self.conversation_id, &title);
305        }
306        let mut final_usage: Option<Usage> = None;
307        let system_prompt = self
308            .agents_context
309            .apply_to_system_prompt(&self.profile().system_prompt);
310        let tool_filter = self.profile().tool_filter.clone();
311        let thinking_budget = self.thinking_budget;
312        loop {
313            let tool_defs = self.tools.definitions_filtered(&tool_filter);
314            let mut stream_rx = self
315                .provider()
316                .stream(
317                    &self.messages,
318                    Some(&system_prompt),
319                    &tool_defs,
320                    8192,
321                    thinking_budget,
322                )
323                .await?;
324            let mut full_text = String::new();
325            let mut full_thinking = String::new();
326            let mut full_thinking_signature = String::new();
327            let mut tool_calls: Vec<PendingToolCall> = Vec::new();
328            let mut current_tool_input = String::new();
329            while let Some(event) = stream_rx.recv().await {
330                match event.event_type {
331                    StreamEventType::TextDelta(text) => {
332                        full_text.push_str(&text);
333                        let _ = event_tx.send(AgentEvent::TextDelta(text));
334                    }
335                    StreamEventType::ThinkingDelta(text) => {
336                        full_thinking.push_str(&text);
337                        let _ = event_tx.send(AgentEvent::ThinkingDelta(text));
338                    }
339                    StreamEventType::ThinkingComplete {
340                        thinking,
341                        signature,
342                    } => {
343                        full_thinking = thinking;
344                        full_thinking_signature = signature;
345                    }
346                    StreamEventType::ToolUseStart { id, name } => {
347                        current_tool_input.clear();
348                        let _ = event_tx.send(AgentEvent::ToolCallStart {
349                            id: id.clone(),
350                            name: name.clone(),
351                        });
352                        tool_calls.push(PendingToolCall {
353                            id,
354                            name,
355                            input: String::new(),
356                        });
357                    }
358                    StreamEventType::ToolUseInputDelta(delta) => {
359                        current_tool_input.push_str(&delta);
360                        let _ = event_tx.send(AgentEvent::ToolCallInputDelta(delta));
361                    }
362                    StreamEventType::ToolUseEnd => {
363                        if let Some(tc) = tool_calls.last_mut() {
364                            tc.input = current_tool_input.clone();
365                        }
366                        current_tool_input.clear();
367                    }
368                    StreamEventType::MessageEnd {
369                        stop_reason: _,
370                        usage,
371                    } => {
372                        self.last_input_tokens = usage.input_tokens;
373                        final_usage = Some(usage);
374                    }
375
376                    _ => {}
377                }
378            }
379
380            let mut content_blocks: Vec<ContentBlock> = Vec::new();
381            if !full_thinking.is_empty() {
382                content_blocks.push(ContentBlock::Thinking {
383                    thinking: full_thinking.clone(),
384                    signature: full_thinking_signature.clone(),
385                });
386            }
387            if !full_text.is_empty() {
388                content_blocks.push(ContentBlock::Text(full_text.clone()));
389            }
390
391            for tc in &tool_calls {
392                let input_value: serde_json::Value =
393                    serde_json::from_str(&tc.input).unwrap_or(serde_json::Value::Null);
394                content_blocks.push(ContentBlock::ToolUse {
395                    id: tc.id.clone(),
396                    name: tc.name.clone(),
397                    input: input_value,
398                });
399            }
400
401            self.messages.push(Message {
402                role: Role::Assistant,
403                content: content_blocks,
404            });
405            let stored_text = if !full_text.is_empty() {
406                full_text.clone()
407            } else {
408                String::from("[tool use]")
409            };
410            let assistant_msg_id =
411                self.db
412                    .add_message(&self.conversation_id, "assistant", &stored_text)?;
413            for tc in &tool_calls {
414                let _ = self
415                    .db
416                    .add_tool_call(&assistant_msg_id, &tc.id, &tc.name, &tc.input);
417            }
418            if tool_calls.is_empty() {
419                let _ = event_tx.send(AgentEvent::TextComplete(full_text));
420                if let Some(usage) = final_usage {
421                    let _ = event_tx.send(AgentEvent::Done { usage });
422                }
423                break;
424            }
425
426            let mut result_blocks: Vec<ContentBlock> = Vec::new();
427
428            for tc in &tool_calls {
429                let input_value: serde_json::Value =
430                    serde_json::from_str(&tc.input).unwrap_or(serde_json::Value::Null);
431                let _ = event_tx.send(AgentEvent::ToolCallExecuting {
432                    id: tc.id.clone(),
433                    name: tc.name.clone(),
434                    input: tc.input.clone(),
435                });
436                let tool_name = tc.name.clone();
437                let tool_input = input_value.clone();
438
439                let exec_result = tokio::time::timeout(std::time::Duration::from_secs(30), async {
440                    tokio::task::block_in_place(|| self.tools.execute(&tool_name, tool_input))
441                })
442                .await;
443
444                let (output, is_error) = match exec_result {
445                    Err(_elapsed) => (
446                        format!("Tool '{}' timed out after 30 seconds.", tc.name),
447                        true,
448                    ),
449                    Ok(Err(e)) => (e.to_string(), true),
450                    Ok(Ok(out)) => (out, false),
451                };
452                tracing::debug!(
453                    "Tool '{}' result (error={}): {}",
454                    tc.name,
455                    is_error,
456                    &output[..output.len().min(200)]
457                );
458
459                let _ = self.db.update_tool_result(&tc.id, &output, is_error);
460                let _ = event_tx.send(AgentEvent::ToolCallResult {
461                    id: tc.id.clone(),
462                    name: tc.name.clone(),
463                    output: output.clone(),
464                    is_error,
465                });
466                result_blocks.push(ContentBlock::ToolResult {
467                    tool_use_id: tc.id.clone(),
468                    content: output,
469                    is_error,
470                });
471            }
472
473            self.messages.push(Message {
474                role: Role::User,
475                content: result_blocks,
476            });
477        }
478
479        Ok(())
480    }
481}