Skip to main content

codetether_agent/session/
mod.rs

1//! Session management
2//!
3//! Sessions track the conversation history and state for agent interactions.
4
5use crate::agent::ToolUse;
6use crate::provider::{Message, Usage};
7use crate::tool::ToolRegistry;
8use anyhow::Result;
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::path::PathBuf;
12use std::sync::Arc;
13use tokio::fs;
14use uuid::Uuid;
15
16/// A conversation session
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct Session {
19    pub id: String,
20    pub title: Option<String>,
21    pub created_at: DateTime<Utc>,
22    pub updated_at: DateTime<Utc>,
23    pub messages: Vec<Message>,
24    pub tool_uses: Vec<ToolUse>,
25    pub usage: Usage,
26    pub agent: String,
27    pub metadata: SessionMetadata,
28}
29
30#[derive(Debug, Clone, Default, Serialize, Deserialize)]
31pub struct SessionMetadata {
32    pub directory: Option<PathBuf>,
33    pub model: Option<String>,
34    pub shared: bool,
35    pub share_url: Option<String>,
36}
37
38impl Session {
39    fn default_model_for_provider(provider: &str) -> String {
40        match provider {
41            "moonshotai" => "kimi-k2.5".to_string(),
42            "anthropic" => "claude-sonnet-4-20250514".to_string(),
43            "openai" => "gpt-4o".to_string(),
44            "google" => "gemini-2.5-pro".to_string(),
45            "zhipuai" => "glm-4.7".to_string(),
46            "openrouter" => "zhipuai/glm-4.7".to_string(),
47            "novita" => "qwen/qwen3-coder-next".to_string(),
48            "github-copilot" | "github-copilot-enterprise" => "gpt-5-mini".to_string(),
49            _ => "glm-4.7".to_string(),
50        }
51    }
52
53    /// Create a new session
54    pub async fn new() -> Result<Self> {
55        let id = Uuid::new_v4().to_string();
56        let now = Utc::now();
57
58        Ok(Self {
59            id,
60            title: None,
61            created_at: now,
62            updated_at: now,
63            messages: Vec::new(),
64            tool_uses: Vec::new(),
65            usage: Usage::default(),
66            agent: "build".to_string(),
67            metadata: SessionMetadata {
68                directory: Some(std::env::current_dir()?),
69                ..Default::default()
70            },
71        })
72    }
73
74    /// Load an existing session
75    pub async fn load(id: &str) -> Result<Self> {
76        let path = Self::session_path(id)?;
77        let content = fs::read_to_string(&path).await?;
78        let session: Session = serde_json::from_str(&content)?;
79        Ok(session)
80    }
81
82    /// Load the last session
83    pub async fn last() -> Result<Self> {
84        let sessions_dir = Self::sessions_dir()?;
85
86        if !sessions_dir.exists() {
87            anyhow::bail!("No sessions found");
88        }
89
90        let mut entries: Vec<tokio::fs::DirEntry> = Vec::new();
91        let mut read_dir = fs::read_dir(&sessions_dir).await?;
92        while let Some(entry) = read_dir.next_entry().await? {
93            entries.push(entry);
94        }
95
96        if entries.is_empty() {
97            anyhow::bail!("No sessions found");
98        }
99
100        // Sort by modification time (most recent first)
101        // Use std::fs::metadata since we can't await in sort_by_key
102        entries.sort_by_key(|e| {
103            std::cmp::Reverse(
104                std::fs::metadata(e.path())
105                    .ok()
106                    .and_then(|m| m.modified().ok())
107                    .unwrap_or(std::time::SystemTime::UNIX_EPOCH),
108            )
109        });
110
111        if let Some(entry) = entries.first() {
112            let content: String = fs::read_to_string(entry.path()).await?;
113            let session: Session = serde_json::from_str(&content)?;
114            return Ok(session);
115        }
116
117        anyhow::bail!("No sessions found")
118    }
119
120    /// Save the session to disk
121    pub async fn save(&self) -> Result<()> {
122        let path = Self::session_path(&self.id)?;
123
124        if let Some(parent) = path.parent() {
125            fs::create_dir_all(parent).await?;
126        }
127
128        let content = serde_json::to_string_pretty(self)?;
129        fs::write(&path, content).await?;
130
131        Ok(())
132    }
133
134    /// Add a message to the session
135    pub fn add_message(&mut self, message: Message) {
136        self.messages.push(message);
137        self.updated_at = Utc::now();
138    }
139
140    /// Execute a prompt and get the result
141    pub async fn prompt(&mut self, message: &str) -> Result<SessionResult> {
142        use crate::provider::{
143            CompletionRequest, ContentPart, ProviderRegistry, Role, parse_model_string,
144        };
145
146        // Load providers from Vault
147        let registry = ProviderRegistry::from_vault().await?;
148
149        let providers = registry.list();
150        if providers.is_empty() {
151            anyhow::bail!(
152                "No providers available. Configure API keys in HashiCorp Vault (for Copilot use `codetether auth copilot`)."
153            );
154        }
155
156        tracing::info!("Available providers: {:?}", providers);
157
158        // Parse model string (format: "provider/model", "provider", or just "model")
159        let (provider_name, model_id) = if let Some(ref model_str) = self.metadata.model {
160            let (prov, model) = parse_model_string(model_str);
161            if prov.is_some() {
162                // Format: provider/model
163                (prov.map(|s| s.to_string()), model.to_string())
164            } else if providers.contains(&model) {
165                // Format: just provider name (e.g., "novita")
166                (Some(model.to_string()), String::new())
167            } else {
168                // Format: just model name
169                (None, model.to_string())
170            }
171        } else {
172            (None, String::new())
173        };
174
175        // Determine which provider to use (prefer zhipuai as default)
176        let selected_provider = provider_name
177            .as_deref()
178            .filter(|p| providers.contains(p))
179            .or_else(|| {
180                if providers.contains(&"zhipuai") {
181                    Some("zhipuai")
182                } else {
183                    providers.first().copied()
184                }
185            })
186            .ok_or_else(|| anyhow::anyhow!("No providers available"))?;
187
188        let provider = registry
189            .get(selected_provider)
190            .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
191
192        // Add user message to session using add_message
193        self.add_message(Message {
194            role: Role::User,
195            content: vec![ContentPart::Text {
196                text: message.to_string(),
197            }],
198        });
199
200        // Generate title if this is the first user message and no title exists
201        if self.title.is_none() {
202            self.generate_title().await?;
203        }
204
205        // Determine model to use
206        let model = if !model_id.is_empty() {
207            model_id
208        } else {
209            Self::default_model_for_provider(selected_provider)
210        };
211
212        // Create tool registry with all available tools
213        let tool_registry = ToolRegistry::with_provider_arc(Arc::clone(&provider), model.clone());
214        let tool_definitions = tool_registry.definitions();
215
216        // Kimi K2.5 requires temperature=1.0
217        let temperature = if model.starts_with("kimi-k2") {
218            Some(1.0)
219        } else {
220            Some(0.7)
221        };
222
223        tracing::info!("Using model: {} via provider: {}", model, selected_provider);
224        tracing::info!("Available tools: {}", tool_definitions.len());
225
226        // Build system prompt with AGENTS.md
227        let cwd = self
228            .metadata
229            .directory
230            .clone()
231            .unwrap_or_else(|| std::env::current_dir().unwrap_or_default());
232        let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
233
234        // Run agentic loop with tool execution
235        let max_steps = 50;
236        let mut final_output = String::new();
237
238        for step in 1..=max_steps {
239            tracing::info!(step = step, "Agent step starting");
240
241            // Build messages with system prompt first
242            let mut messages = vec![Message {
243                role: Role::System,
244                content: vec![ContentPart::Text {
245                    text: system_prompt.clone(),
246                }],
247            }];
248            messages.extend(self.messages.clone());
249
250            // Create completion request with tools
251            let request = CompletionRequest {
252                messages,
253                tools: tool_definitions.clone(),
254                model: model.clone(),
255                temperature,
256                top_p: None,
257                max_tokens: Some(8192),
258                stop: Vec::new(),
259            };
260
261            // Call the provider
262            let response = provider.complete(request).await?;
263
264            // Record token usage
265            crate::telemetry::TOKEN_USAGE.record_model_usage(
266                &model,
267                response.usage.prompt_tokens as u64,
268                response.usage.completion_tokens as u64,
269            );
270
271            // Extract tool calls from response
272            let tool_calls: Vec<(String, String, serde_json::Value)> = response
273                .message
274                .content
275                .iter()
276                .filter_map(|part| {
277                    if let ContentPart::ToolCall {
278                        id,
279                        name,
280                        arguments,
281                    } = part
282                    {
283                        // Parse arguments JSON string into Value
284                        let args: serde_json::Value =
285                            serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
286                        Some((id.clone(), name.clone(), args))
287                    } else {
288                        None
289                    }
290                })
291                .collect();
292
293            // Collect text output
294            for part in &response.message.content {
295                if let ContentPart::Text { text } = part {
296                    if !text.is_empty() {
297                        final_output.push_str(text);
298                        final_output.push('\n');
299                    }
300                }
301            }
302
303            // If no tool calls, we're done
304            if tool_calls.is_empty() {
305                self.add_message(response.message.clone());
306                break;
307            }
308
309            // Add assistant message with tool calls
310            self.add_message(response.message.clone());
311
312            tracing::info!(
313                step = step,
314                num_tools = tool_calls.len(),
315                "Executing tool calls"
316            );
317
318            // Execute each tool call
319            for (tool_id, tool_name, tool_input) in tool_calls {
320                tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
321
322                // Get and execute the tool
323                let content = if let Some(tool) = tool_registry.get(&tool_name) {
324                    match tool.execute(tool_input.clone()).await {
325                        Ok(result) => {
326                            tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
327                            result.output
328                        }
329                        Err(e) => {
330                            tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
331                            format!("Error: {}", e)
332                        }
333                    }
334                } else {
335                    tracing::warn!(tool = %tool_name, "Tool not found");
336                    format!("Error: Unknown tool '{}'", tool_name)
337                };
338
339                // Add tool result message
340                self.add_message(Message {
341                    role: Role::Tool,
342                    content: vec![ContentPart::ToolResult {
343                        tool_call_id: tool_id,
344                        content,
345                    }],
346                });
347            }
348        }
349
350        // Save session after each prompt to persist messages
351        self.save().await?;
352
353        Ok(SessionResult {
354            text: final_output.trim().to_string(),
355            session_id: self.id.clone(),
356        })
357    }
358
359    /// Process a user message with real-time event streaming for UI updates.
360    /// Events are sent through the provided channel as tool calls execute.
361    pub async fn prompt_with_events(
362        &mut self,
363        message: &str,
364        event_tx: tokio::sync::mpsc::Sender<SessionEvent>,
365    ) -> Result<SessionResult> {
366        use crate::provider::{
367            CompletionRequest, ContentPart, ProviderRegistry, Role, parse_model_string,
368        };
369
370        let _ = event_tx.send(SessionEvent::Thinking).await;
371
372        // Load provider registry from Vault
373        let registry = ProviderRegistry::from_vault().await?;
374        let providers = registry.list();
375        if providers.is_empty() {
376            anyhow::bail!(
377                "No providers available. Configure API keys in HashiCorp Vault (for Copilot use `codetether auth copilot`)."
378            );
379        }
380        tracing::info!("Available providers: {:?}", providers);
381
382        // Parse model string (format: "provider/model", "provider", or just "model")
383        let (provider_name, model_id) = if let Some(ref model_str) = self.metadata.model {
384            let (prov, model) = parse_model_string(model_str);
385            if prov.is_some() {
386                (prov.map(|s| s.to_string()), model.to_string())
387            } else if providers.contains(&model) {
388                (Some(model.to_string()), String::new())
389            } else {
390                (None, model.to_string())
391            }
392        } else {
393            (None, String::new())
394        };
395
396        // Determine which provider to use (prefer zhipuai as default)
397        let selected_provider = provider_name
398            .as_deref()
399            .filter(|p| providers.contains(p))
400            .or_else(|| {
401                if providers.contains(&"zhipuai") {
402                    Some("zhipuai")
403                } else {
404                    providers.first().copied()
405                }
406            })
407            .ok_or_else(|| anyhow::anyhow!("No providers available"))?;
408
409        let provider = registry
410            .get(selected_provider)
411            .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
412
413        // Add user message
414        self.add_message(Message {
415            role: Role::User,
416            content: vec![ContentPart::Text {
417                text: message.to_string(),
418            }],
419        });
420
421        // Generate title if needed
422        if self.title.is_none() {
423            self.generate_title().await?;
424        }
425
426        // Determine model
427        let model = if !model_id.is_empty() {
428            model_id
429        } else {
430            Self::default_model_for_provider(selected_provider)
431        };
432
433        // Create tool registry
434        let tool_registry = ToolRegistry::with_provider_arc(Arc::clone(&provider), model.clone());
435        let tool_definitions = tool_registry.definitions();
436
437        let temperature = if model.starts_with("kimi-k2") {
438            Some(1.0)
439        } else {
440            Some(0.7)
441        };
442
443        tracing::info!("Using model: {} via provider: {}", model, selected_provider);
444        tracing::info!("Available tools: {}", tool_definitions.len());
445
446        // Build system prompt
447        let cwd = std::env::var("PWD")
448            .map(std::path::PathBuf::from)
449            .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
450        let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
451
452        let mut final_output = String::new();
453        let max_steps = 50;
454
455        for step in 1..=max_steps {
456            tracing::info!(step = step, "Agent step starting");
457            let _ = event_tx.send(SessionEvent::Thinking).await;
458
459            // Build messages with system prompt first
460            let mut messages = vec![Message {
461                role: Role::System,
462                content: vec![ContentPart::Text {
463                    text: system_prompt.clone(),
464                }],
465            }];
466            messages.extend(self.messages.clone());
467
468            let request = CompletionRequest {
469                messages,
470                tools: tool_definitions.clone(),
471                model: model.clone(),
472                temperature,
473                top_p: None,
474                max_tokens: Some(8192),
475                stop: Vec::new(),
476            };
477
478            let response = provider.complete(request).await?;
479
480            crate::telemetry::TOKEN_USAGE.record_model_usage(
481                &model,
482                response.usage.prompt_tokens as u64,
483                response.usage.completion_tokens as u64,
484            );
485
486            // Extract tool calls
487            let tool_calls: Vec<(String, String, serde_json::Value)> = response
488                .message
489                .content
490                .iter()
491                .filter_map(|part| {
492                    if let ContentPart::ToolCall {
493                        id,
494                        name,
495                        arguments,
496                    } = part
497                    {
498                        let args: serde_json::Value =
499                            serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
500                        Some((id.clone(), name.clone(), args))
501                    } else {
502                        None
503                    }
504                })
505                .collect();
506
507            // Collect text output
508            for part in &response.message.content {
509                if let ContentPart::Text { text } = part {
510                    if !text.is_empty() {
511                        final_output.push_str(text);
512                        final_output.push('\n');
513                        let _ = event_tx.send(SessionEvent::TextChunk(text.clone())).await;
514                    }
515                }
516            }
517
518            if tool_calls.is_empty() {
519                self.add_message(response.message.clone());
520                break;
521            }
522
523            self.add_message(response.message.clone());
524
525            tracing::info!(
526                step = step,
527                num_tools = tool_calls.len(),
528                "Executing tool calls"
529            );
530
531            // Execute each tool call with events
532            for (tool_id, tool_name, tool_input) in tool_calls {
533                let args_str = serde_json::to_string(&tool_input).unwrap_or_default();
534                let _ = event_tx
535                    .send(SessionEvent::ToolCallStart {
536                        name: tool_name.clone(),
537                        arguments: args_str,
538                    })
539                    .await;
540
541                tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
542
543                let (content, success) = if let Some(tool) = tool_registry.get(&tool_name) {
544                    match tool.execute(tool_input.clone()).await {
545                        Ok(result) => {
546                            tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
547                            (result.output, result.success)
548                        }
549                        Err(e) => {
550                            tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
551                            (format!("Error: {}", e), false)
552                        }
553                    }
554                } else {
555                    tracing::warn!(tool = %tool_name, "Tool not found");
556                    (format!("Error: Unknown tool '{}'", tool_name), false)
557                };
558
559                let _ = event_tx
560                    .send(SessionEvent::ToolCallComplete {
561                        name: tool_name.clone(),
562                        output: content.clone(),
563                        success,
564                    })
565                    .await;
566
567                self.add_message(Message {
568                    role: Role::Tool,
569                    content: vec![ContentPart::ToolResult {
570                        tool_call_id: tool_id,
571                        content,
572                    }],
573                });
574            }
575        }
576
577        self.save().await?;
578
579        let _ = event_tx
580            .send(SessionEvent::TextComplete(final_output.trim().to_string()))
581            .await;
582        let _ = event_tx.send(SessionEvent::Done).await;
583
584        Ok(SessionResult {
585            text: final_output.trim().to_string(),
586            session_id: self.id.clone(),
587        })
588    }
589
590    /// Generate a title for the session based on the first message
591    /// Only sets title if not already set (for initial title generation)
592    pub async fn generate_title(&mut self) -> Result<()> {
593        if self.title.is_some() {
594            return Ok(());
595        }
596
597        // Get first user message
598        let first_message = self
599            .messages
600            .iter()
601            .find(|m| m.role == crate::provider::Role::User);
602
603        if let Some(msg) = first_message {
604            let text: String = msg
605                .content
606                .iter()
607                .filter_map(|p| match p {
608                    crate::provider::ContentPart::Text { text } => Some(text.clone()),
609                    _ => None,
610                })
611                .collect::<Vec<_>>()
612                .join(" ");
613
614            // Truncate to reasonable length
615            self.title = Some(if text.len() > 50 {
616                format!("{}...", &text[..47])
617            } else {
618                text
619            });
620        }
621
622        Ok(())
623    }
624
625    /// Regenerate the title based on the first message, even if already set
626    /// Use this for on-demand title updates or after context changes
627    pub async fn regenerate_title(&mut self) -> Result<()> {
628        // Get first user message
629        let first_message = self
630            .messages
631            .iter()
632            .find(|m| m.role == crate::provider::Role::User);
633
634        if let Some(msg) = first_message {
635            let text: String = msg
636                .content
637                .iter()
638                .filter_map(|p| match p {
639                    crate::provider::ContentPart::Text { text } => Some(text.clone()),
640                    _ => None,
641                })
642                .collect::<Vec<_>>()
643                .join(" ");
644
645            // Truncate to reasonable length
646            self.title = Some(if text.len() > 50 {
647                format!("{}...", &text[..47])
648            } else {
649                text
650            });
651        }
652
653        Ok(())
654    }
655
656    /// Set a custom title for the session
657    pub fn set_title(&mut self, title: impl Into<String>) {
658        self.title = Some(title.into());
659        self.updated_at = Utc::now();
660    }
661
662    /// Clear the title, allowing it to be regenerated
663    pub fn clear_title(&mut self) {
664        self.title = None;
665        self.updated_at = Utc::now();
666    }
667
668    /// Handle context change - updates metadata and optionally regenerates title
669    /// Call this when the session context changes (e.g., directory change, model change)
670    pub async fn on_context_change(&mut self, regenerate_title: bool) -> Result<()> {
671        self.updated_at = Utc::now();
672
673        if regenerate_title {
674            self.regenerate_title().await?;
675        }
676
677        Ok(())
678    }
679
680    /// Get the sessions directory
681    fn sessions_dir() -> Result<PathBuf> {
682        crate::config::Config::data_dir()
683            .map(|d| d.join("sessions"))
684            .ok_or_else(|| anyhow::anyhow!("Could not determine data directory"))
685    }
686
687    /// Get the path for a session file
688    fn session_path(id: &str) -> Result<PathBuf> {
689        Ok(Self::sessions_dir()?.join(format!("{}.json", id)))
690    }
691}
692
693/// Result from a session prompt
694#[derive(Debug, Clone, Serialize, Deserialize)]
695pub struct SessionResult {
696    pub text: String,
697    pub session_id: String,
698}
699
/// Progress notifications sent while a prompt is processed, intended for
/// real-time UI updates.
#[derive(Debug, Clone)]
pub enum SessionEvent {
    /// The agent is working (emitted before each model call).
    Thinking,
    /// A tool call is about to run; `arguments` is the JSON-encoded input.
    ToolCallStart {
        name: String,
        arguments: String,
    },
    /// A tool call finished with the given output and success flag.
    ToolCallComplete {
        name: String,
        output: String,
        success: bool,
    },
    /// A non-empty piece of assistant text.
    TextChunk(String),
    /// The full, trimmed assistant text for the prompt.
    TextComplete(String),
    /// Processing finished.
    Done,
    /// An error occurred while processing.
    Error(String),
}
722
723/// List all sessions
724pub async fn list_sessions() -> Result<Vec<SessionSummary>> {
725    let sessions_dir = crate::config::Config::data_dir()
726        .map(|d| d.join("sessions"))
727        .ok_or_else(|| anyhow::anyhow!("Could not determine data directory"))?;
728
729    if !sessions_dir.exists() {
730        return Ok(Vec::new());
731    }
732
733    let mut summaries = Vec::new();
734    let mut entries = fs::read_dir(&sessions_dir).await?;
735
736    while let Some(entry) = entries.next_entry().await? {
737        let path = entry.path();
738        if path.extension().map(|e| e == "json").unwrap_or(false) {
739            if let Ok(content) = fs::read_to_string(&path).await {
740                if let Ok(session) = serde_json::from_str::<Session>(&content) {
741                    summaries.push(SessionSummary {
742                        id: session.id,
743                        title: session.title,
744                        created_at: session.created_at,
745                        updated_at: session.updated_at,
746                        message_count: session.messages.len(),
747                        agent: session.agent,
748                    });
749                }
750            }
751        }
752    }
753
754    summaries.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
755    Ok(summaries)
756}
757
758/// Summary of a session for listing
759#[derive(Debug, Clone, Serialize, Deserialize)]
760pub struct SessionSummary {
761    pub id: String,
762    pub title: Option<String>,
763    pub created_at: DateTime<Utc>,
764    pub updated_at: DateTime<Utc>,
765    pub message_count: usize,
766    pub agent: String,
767}
768
769// Async helper for Vec - kept for potential future use
770#[allow(dead_code)]
771use futures::StreamExt;
772
773#[allow(dead_code)]
774trait AsyncCollect<T> {
775    async fn collect(self) -> Vec<T>;
776}
777
778#[allow(dead_code)]
779impl<S, T> AsyncCollect<T> for S
780where
781    S: futures::Stream<Item = T> + Unpin,
782{
783    async fn collect(mut self) -> Vec<T> {
784        let mut items = Vec::new();
785        while let Some(item) = self.next().await {
786            items.push(item);
787        }
788        items
789    }
790}