Skip to main content

codetether_agent/session/
mod.rs

//! Session management
//!
//! Sessions track the conversation history and state for agent interactions.
4
5use crate::agent::ToolUse;
6use crate::provider::{Message, Usage};
7use crate::telemetry::TokenCounts;
8use crate::tool::ToolRegistry;
9use anyhow::Result;
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use std::path::PathBuf;
13use std::sync::Arc;
14use tokio::fs;
15use uuid::Uuid;
16
17/// A conversation session
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Session {
20    pub id: String,
21    pub title: Option<String>,
22    pub created_at: DateTime<Utc>,
23    pub updated_at: DateTime<Utc>,
24    pub messages: Vec<Message>,
25    pub tool_uses: Vec<ToolUse>,
26    pub usage: Usage,
27    pub agent: String,
28    pub metadata: SessionMetadata,
29}
30
31#[derive(Debug, Clone, Default, Serialize, Deserialize)]
32pub struct SessionMetadata {
33    pub directory: Option<PathBuf>,
34    pub model: Option<String>,
35    pub shared: bool,
36    pub share_url: Option<String>,
37}
38
39impl Session {
40    /// Create a new session
41    pub async fn new() -> Result<Self> {
42        let id = Uuid::new_v4().to_string();
43        let now = Utc::now();
44
45        Ok(Self {
46            id,
47            title: None,
48            created_at: now,
49            updated_at: now,
50            messages: Vec::new(),
51            tool_uses: Vec::new(),
52            usage: Usage::default(),
53            agent: "build".to_string(),
54            metadata: SessionMetadata {
55                directory: Some(std::env::current_dir()?),
56                ..Default::default()
57            },
58        })
59    }
60
61    /// Load an existing session
62    pub async fn load(id: &str) -> Result<Self> {
63        let path = Self::session_path(id)?;
64        let content = fs::read_to_string(&path).await?;
65        let session: Session = serde_json::from_str(&content)?;
66        Ok(session)
67    }
68
69    /// Load the last session
70    pub async fn last() -> Result<Self> {
71        let sessions_dir = Self::sessions_dir()?;
72
73        if !sessions_dir.exists() {
74            anyhow::bail!("No sessions found");
75        }
76
77        let mut entries: Vec<tokio::fs::DirEntry> = Vec::new();
78        let mut read_dir = fs::read_dir(&sessions_dir).await?;
79        while let Some(entry) = read_dir.next_entry().await? {
80            entries.push(entry);
81        }
82
83        if entries.is_empty() {
84            anyhow::bail!("No sessions found");
85        }
86
87        // Sort by modification time (most recent first)
88        // Use std::fs::metadata since we can't await in sort_by_key
89        entries.sort_by_key(|e| {
90            std::cmp::Reverse(
91                std::fs::metadata(e.path())
92                    .ok()
93                    .and_then(|m| m.modified().ok())
94                    .unwrap_or(std::time::SystemTime::UNIX_EPOCH),
95            )
96        });
97
98        if let Some(entry) = entries.first() {
99            let content: String = fs::read_to_string(entry.path()).await?;
100            let session: Session = serde_json::from_str(&content)?;
101            return Ok(session);
102        }
103
104        anyhow::bail!("No sessions found")
105    }
106
107    /// Save the session to disk
108    pub async fn save(&self) -> Result<()> {
109        let path = Self::session_path(&self.id)?;
110
111        if let Some(parent) = path.parent() {
112            fs::create_dir_all(parent).await?;
113        }
114
115        let content = serde_json::to_string_pretty(self)?;
116        fs::write(&path, content).await?;
117
118        Ok(())
119    }
120
121    /// Add a message to the session
122    pub fn add_message(&mut self, message: Message) {
123        self.messages.push(message);
124        self.updated_at = Utc::now();
125    }
126
127    /// Execute a prompt and get the result
128    pub async fn prompt(&mut self, message: &str) -> Result<SessionResult> {
129        use crate::provider::{
130            CompletionRequest, ContentPart, ProviderRegistry, Role, parse_model_string,
131        };
132
133        // Load providers from Vault
134        let registry = ProviderRegistry::from_vault().await?;
135
136        let providers = registry.list();
137        if providers.is_empty() {
138            anyhow::bail!("No providers available. Configure API keys in HashiCorp Vault.");
139        }
140
141        tracing::info!("Available providers: {:?}", providers);
142
143        // Parse model string (format: "provider/model", "provider", or just "model")
144        let (provider_name, model_id) = if let Some(ref model_str) = self.metadata.model {
145            let (prov, model) = parse_model_string(model_str);
146            if prov.is_some() {
147                // Format: provider/model
148                (prov.map(|s| s.to_string()), model.to_string())
149            } else if providers.contains(&model) {
150                // Format: just provider name (e.g., "novita")
151                (Some(model.to_string()), String::new())
152            } else {
153                // Format: just model name
154                (None, model.to_string())
155            }
156        } else {
157            (None, String::new())
158        };
159
160        // Determine which provider to use (prefer zhipuai as default)
161        let selected_provider = provider_name
162            .as_deref()
163            .filter(|p| providers.contains(p))
164            .unwrap_or_else(|| {
165                if providers.contains(&"zhipuai") {
166                    "zhipuai"
167                } else {
168                    providers[0]
169                }
170            });
171
172        let provider = registry
173            .get(selected_provider)
174            .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
175
176        // Add user message to session using add_message
177        self.add_message(Message {
178            role: Role::User,
179            content: vec![ContentPart::Text {
180                text: message.to_string(),
181            }],
182        });
183
184        // Generate title if this is the first user message and no title exists
185        if self.title.is_none() {
186            self.generate_title().await?;
187        }
188
189        // Determine model to use
190        let model = if !model_id.is_empty() {
191            model_id
192        } else {
193            // Default models per provider
194            match selected_provider {
195                "moonshotai" => "kimi-k2.5".to_string(),
196                "anthropic" => "claude-sonnet-4-20250514".to_string(),
197                "openai" => "gpt-4o".to_string(),
198                "google" => "gemini-2.5-pro".to_string(),
199                "zhipuai" => "glm-4.7".to_string(),
200                "openrouter" => "zhipuai/glm-4.7".to_string(),
201                "novita" => "qwen/qwen3-coder-next".to_string(),
202                _ => "glm-4.7".to_string(),
203            }
204        };
205
206        // Create tool registry with all available tools
207        let tool_registry = ToolRegistry::with_provider_arc(Arc::clone(&provider), model.clone());
208        let tool_definitions = tool_registry.definitions();
209
210        // Kimi K2.5 requires temperature=1.0
211        let temperature = if model.starts_with("kimi-k2") {
212            Some(1.0)
213        } else {
214            Some(0.7)
215        };
216
217        tracing::info!("Using model: {} via provider: {}", model, selected_provider);
218        tracing::info!("Available tools: {}", tool_definitions.len());
219
220        // Build system prompt with AGENTS.md
221        let cwd = self
222            .metadata
223            .directory
224            .clone()
225            .unwrap_or_else(|| std::env::current_dir().unwrap_or_default());
226        let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
227
228        // Run agentic loop with tool execution
229        let max_steps = 50;
230        let mut final_output = String::new();
231
232        for step in 1..=max_steps {
233            tracing::info!(step = step, "Agent step starting");
234
235            // Build messages with system prompt first
236            let mut messages = vec![Message {
237                role: Role::System,
238                content: vec![ContentPart::Text {
239                    text: system_prompt.clone(),
240                }],
241            }];
242            messages.extend(self.messages.clone());
243
244            // Create completion request with tools
245            let request = CompletionRequest {
246                messages,
247                tools: tool_definitions.clone(),
248                model: model.clone(),
249                temperature,
250                top_p: None,
251                max_tokens: Some(8192),
252                stop: Vec::new(),
253            };
254
255            // Call the provider
256            let response = provider.complete(request).await?;
257
258            // Record token usage
259            crate::telemetry::TOKEN_USAGE.record_model_usage(
260                &model,
261                response.usage.prompt_tokens as u64,
262                response.usage.completion_tokens as u64,
263            );
264
265            // Extract tool calls from response
266            let tool_calls: Vec<(String, String, serde_json::Value)> = response
267                .message
268                .content
269                .iter()
270                .filter_map(|part| {
271                    if let ContentPart::ToolCall {
272                        id,
273                        name,
274                        arguments,
275                    } = part
276                    {
277                        // Parse arguments JSON string into Value
278                        let args: serde_json::Value =
279                            serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
280                        Some((id.clone(), name.clone(), args))
281                    } else {
282                        None
283                    }
284                })
285                .collect();
286
287            // Collect text output
288            for part in &response.message.content {
289                if let ContentPart::Text { text } = part {
290                    if !text.is_empty() {
291                        final_output.push_str(text);
292                        final_output.push('\n');
293                    }
294                }
295            }
296
297            // If no tool calls, we're done
298            if tool_calls.is_empty() {
299                self.add_message(response.message.clone());
300                break;
301            }
302
303            // Add assistant message with tool calls
304            self.add_message(response.message.clone());
305
306            tracing::info!(
307                step = step,
308                num_tools = tool_calls.len(),
309                "Executing tool calls"
310            );
311
312            // Execute each tool call
313            for (tool_id, tool_name, tool_input) in tool_calls {
314                tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
315
316                // Get and execute the tool
317                let content = if let Some(tool) = tool_registry.get(&tool_name) {
318                    match tool.execute(tool_input.clone()).await {
319                        Ok(result) => {
320                            tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
321                            result.output
322                        }
323                        Err(e) => {
324                            tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
325                            format!("Error: {}", e)
326                        }
327                    }
328                } else {
329                    tracing::warn!(tool = %tool_name, "Tool not found");
330                    format!("Error: Unknown tool '{}'", tool_name)
331                };
332
333                // Add tool result message
334                self.add_message(Message {
335                    role: Role::Tool,
336                    content: vec![ContentPart::ToolResult {
337                        tool_call_id: tool_id,
338                        content,
339                    }],
340                });
341            }
342        }
343
344        // Save session after each prompt to persist messages
345        self.save().await?;
346
347        Ok(SessionResult {
348            text: final_output.trim().to_string(),
349            session_id: self.id.clone(),
350        })
351    }
352
353    /// Process a user message with real-time event streaming for UI updates.
354    /// Events are sent through the provided channel as tool calls execute.
355    pub async fn prompt_with_events(
356        &mut self,
357        message: &str,
358        event_tx: tokio::sync::mpsc::Sender<SessionEvent>,
359    ) -> Result<SessionResult> {
360        use crate::provider::{CompletionRequest, ContentPart, ProviderRegistry, Role, parse_model_string};
361
362        let _ = event_tx.send(SessionEvent::Thinking).await;
363
364        // Load provider registry from Vault
365        let registry = ProviderRegistry::from_vault().await?;
366        let providers = registry.list();
367        tracing::info!("Available providers: {:?}", providers);
368
369        // Parse model string (format: "provider/model", "provider", or just "model")
370        let (provider_name, model_id) = if let Some(ref model_str) = self.metadata.model {
371            let (prov, model) = parse_model_string(model_str);
372            if prov.is_some() {
373                (prov.map(|s| s.to_string()), model.to_string())
374            } else if providers.contains(&model) {
375                (Some(model.to_string()), String::new())
376            } else {
377                (None, model.to_string())
378            }
379        } else {
380            (None, String::new())
381        };
382
383        // Determine which provider to use (prefer zhipuai as default)
384        let selected_provider = provider_name
385            .as_deref()
386            .filter(|p| providers.contains(p))
387            .unwrap_or_else(|| {
388                if providers.contains(&"zhipuai") {
389                    "zhipuai"
390                } else {
391                    providers[0]
392                }
393            });
394
395        let provider = registry
396            .get(selected_provider)
397            .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
398
399        // Add user message
400        self.add_message(Message {
401            role: Role::User,
402            content: vec![ContentPart::Text {
403                text: message.to_string(),
404            }],
405        });
406
407        // Generate title if needed
408        if self.title.is_none() {
409            self.generate_title().await?;
410        }
411
412        // Determine model
413        let model = if !model_id.is_empty() {
414            model_id
415        } else {
416            match selected_provider {
417                "moonshotai" => "kimi-k2.5".to_string(),
418                "anthropic" => "claude-sonnet-4-20250514".to_string(),
419                "openai" => "gpt-4o".to_string(),
420                "google" => "gemini-2.5-pro".to_string(),
421                "zhipuai" => "glm-4.7".to_string(),
422                "openrouter" => "zhipuai/glm-4.7".to_string(),
423                "novita" => "qwen/qwen3-coder-next".to_string(),
424                _ => "glm-4.7".to_string(),
425            }
426        };
427
428        // Create tool registry
429        let tool_registry = ToolRegistry::with_provider_arc(Arc::clone(&provider), model.clone());
430        let tool_definitions = tool_registry.definitions();
431
432        let temperature = if model.starts_with("kimi-k2") {
433            Some(1.0)
434        } else {
435            Some(0.7)
436        };
437
438        tracing::info!("Using model: {} via provider: {}", model, selected_provider);
439        tracing::info!("Available tools: {}", tool_definitions.len());
440
441        // Build system prompt
442        let cwd = std::env::var("PWD")
443            .map(std::path::PathBuf::from)
444            .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
445        let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
446
447        let mut final_output = String::new();
448        let max_steps = 50;
449
450        for step in 1..=max_steps {
451            tracing::info!(step = step, "Agent step starting");
452            let _ = event_tx.send(SessionEvent::Thinking).await;
453
454            // Build messages with system prompt first
455            let mut messages = vec![Message {
456                role: Role::System,
457                content: vec![ContentPart::Text {
458                    text: system_prompt.clone(),
459                }],
460            }];
461            messages.extend(self.messages.clone());
462
463            let request = CompletionRequest {
464                messages,
465                tools: tool_definitions.clone(),
466                model: model.clone(),
467                temperature,
468                top_p: None,
469                max_tokens: Some(8192),
470                stop: Vec::new(),
471            };
472
473            let response = provider.complete(request).await?;
474
475            crate::telemetry::TOKEN_USAGE.record_model_usage(
476                &model,
477                response.usage.prompt_tokens as u64,
478                response.usage.completion_tokens as u64,
479            );
480
481            // Extract tool calls
482            let tool_calls: Vec<(String, String, serde_json::Value)> = response
483                .message
484                .content
485                .iter()
486                .filter_map(|part| {
487                    if let ContentPart::ToolCall { id, name, arguments } = part {
488                        let args: serde_json::Value =
489                            serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
490                        Some((id.clone(), name.clone(), args))
491                    } else {
492                        None
493                    }
494                })
495                .collect();
496
497            // Collect text output
498            for part in &response.message.content {
499                if let ContentPart::Text { text } = part {
500                    if !text.is_empty() {
501                        final_output.push_str(text);
502                        final_output.push('\n');
503                        let _ = event_tx.send(SessionEvent::TextChunk(text.clone())).await;
504                    }
505                }
506            }
507
508            if tool_calls.is_empty() {
509                self.add_message(response.message.clone());
510                break;
511            }
512
513            self.add_message(response.message.clone());
514
515            tracing::info!(step = step, num_tools = tool_calls.len(), "Executing tool calls");
516
517            // Execute each tool call with events
518            for (tool_id, tool_name, tool_input) in tool_calls {
519                let args_str = serde_json::to_string(&tool_input).unwrap_or_default();
520                let _ = event_tx
521                    .send(SessionEvent::ToolCallStart {
522                        name: tool_name.clone(),
523                        arguments: args_str,
524                    })
525                    .await;
526
527                tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
528
529                let (content, success) = if let Some(tool) = tool_registry.get(&tool_name) {
530                    match tool.execute(tool_input.clone()).await {
531                        Ok(result) => {
532                            tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
533                            (result.output, result.success)
534                        }
535                        Err(e) => {
536                            tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
537                            (format!("Error: {}", e), false)
538                        }
539                    }
540                } else {
541                    tracing::warn!(tool = %tool_name, "Tool not found");
542                    (format!("Error: Unknown tool '{}'", tool_name), false)
543                };
544
545                let _ = event_tx
546                    .send(SessionEvent::ToolCallComplete {
547                        name: tool_name.clone(),
548                        output: content.clone(),
549                        success,
550                    })
551                    .await;
552
553                self.add_message(Message {
554                    role: Role::Tool,
555                    content: vec![ContentPart::ToolResult {
556                        tool_call_id: tool_id,
557                        content,
558                    }],
559                });
560            }
561        }
562
563        self.save().await?;
564
565        let _ = event_tx
566            .send(SessionEvent::TextComplete(final_output.trim().to_string()))
567            .await;
568        let _ = event_tx.send(SessionEvent::Done).await;
569
570        Ok(SessionResult {
571            text: final_output.trim().to_string(),
572            session_id: self.id.clone(),
573        })
574    }
575
576    /// Generate a title for the session based on the first message
577    /// Only sets title if not already set (for initial title generation)
578    pub async fn generate_title(&mut self) -> Result<()> {
579        if self.title.is_some() {
580            return Ok(());
581        }
582
583        // Get first user message
584        let first_message = self
585            .messages
586            .iter()
587            .find(|m| m.role == crate::provider::Role::User);
588
589        if let Some(msg) = first_message {
590            let text: String = msg
591                .content
592                .iter()
593                .filter_map(|p| match p {
594                    crate::provider::ContentPart::Text { text } => Some(text.clone()),
595                    _ => None,
596                })
597                .collect::<Vec<_>>()
598                .join(" ");
599
600            // Truncate to reasonable length
601            self.title = Some(if text.len() > 50 {
602                format!("{}...", &text[..47])
603            } else {
604                text
605            });
606        }
607
608        Ok(())
609    }
610
611    /// Regenerate the title based on the first message, even if already set
612    /// Use this for on-demand title updates or after context changes
613    pub async fn regenerate_title(&mut self) -> Result<()> {
614        // Get first user message
615        let first_message = self
616            .messages
617            .iter()
618            .find(|m| m.role == crate::provider::Role::User);
619
620        if let Some(msg) = first_message {
621            let text: String = msg
622                .content
623                .iter()
624                .filter_map(|p| match p {
625                    crate::provider::ContentPart::Text { text } => Some(text.clone()),
626                    _ => None,
627                })
628                .collect::<Vec<_>>()
629                .join(" ");
630
631            // Truncate to reasonable length
632            self.title = Some(if text.len() > 50 {
633                format!("{}...", &text[..47])
634            } else {
635                text
636            });
637        }
638
639        Ok(())
640    }
641
642    /// Set a custom title for the session
643    pub fn set_title(&mut self, title: impl Into<String>) {
644        self.title = Some(title.into());
645        self.updated_at = Utc::now();
646    }
647
648    /// Clear the title, allowing it to be regenerated
649    pub fn clear_title(&mut self) {
650        self.title = None;
651        self.updated_at = Utc::now();
652    }
653
654    /// Handle context change - updates metadata and optionally regenerates title
655    /// Call this when the session context changes (e.g., directory change, model change)
656    pub async fn on_context_change(&mut self, regenerate_title: bool) -> Result<()> {
657        self.updated_at = Utc::now();
658
659        if regenerate_title {
660            self.regenerate_title().await?;
661        }
662
663        Ok(())
664    }
665
666    /// Get the sessions directory
667    fn sessions_dir() -> Result<PathBuf> {
668        crate::config::Config::data_dir()
669            .map(|d| d.join("sessions"))
670            .ok_or_else(|| anyhow::anyhow!("Could not determine data directory"))
671    }
672
673    /// Get the path for a session file
674    fn session_path(id: &str) -> Result<PathBuf> {
675        Ok(Self::sessions_dir()?.join(format!("{}.json", id)))
676    }
677}
678
679/// Result from a session prompt
680#[derive(Debug, Clone, Serialize, Deserialize)]
681pub struct SessionResult {
682    pub text: String,
683    pub session_id: String,
684}
685
/// Progress notifications sent by `Session::prompt_with_events` so that a UI
/// can follow a prompt while it is being processed.
#[derive(Clone, Debug)]
pub enum SessionEvent {
    /// The agent is working; sent when the prompt starts and again at the
    /// beginning of every step.
    Thinking,
    /// A tool is about to run. `arguments` holds the call's arguments
    /// serialized as a JSON string.
    ToolCallStart { name: String, arguments: String },
    /// A tool call finished. `output` is the tool's output, or an error text
    /// when the tool failed or is unknown; `success` is `false` in those cases.
    ToolCallComplete {
        name: String,
        output: String,
        success: bool,
    },
    /// One non-empty text part of a model response.
    TextChunk(String),
    /// The complete, trimmed text of the prompt; sent after the session was saved.
    TextComplete(String),
    /// Processing has finished; sent after `TextComplete`.
    Done,
    /// An error description. Not emitted by the code in this file.
    Error(String),
}
704
705/// List all sessions
706pub async fn list_sessions() -> Result<Vec<SessionSummary>> {
707    let sessions_dir = crate::config::Config::data_dir()
708        .map(|d| d.join("sessions"))
709        .ok_or_else(|| anyhow::anyhow!("Could not determine data directory"))?;
710
711    if !sessions_dir.exists() {
712        return Ok(Vec::new());
713    }
714
715    let mut summaries = Vec::new();
716    let mut entries = fs::read_dir(&sessions_dir).await?;
717
718    while let Some(entry) = entries.next_entry().await? {
719        let path = entry.path();
720        if path.extension().map(|e| e == "json").unwrap_or(false) {
721            if let Ok(content) = fs::read_to_string(&path).await {
722                if let Ok(session) = serde_json::from_str::<Session>(&content) {
723                    summaries.push(SessionSummary {
724                        id: session.id,
725                        title: session.title,
726                        created_at: session.created_at,
727                        updated_at: session.updated_at,
728                        message_count: session.messages.len(),
729                        agent: session.agent,
730                    });
731                }
732            }
733        }
734    }
735
736    summaries.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
737    Ok(summaries)
738}
739
740/// Summary of a session for listing
741#[derive(Debug, Clone, Serialize, Deserialize)]
742pub struct SessionSummary {
743    pub id: String,
744    pub title: Option<String>,
745    pub created_at: DateTime<Utc>,
746    pub updated_at: DateTime<Utc>,
747    pub message_count: usize,
748    pub agent: String,
749}
750
751// Async helper for Vec - kept for potential future use
752#[allow(dead_code)]
753use futures::StreamExt;
754
755#[allow(dead_code)]
756trait AsyncCollect<T> {
757    async fn collect(self) -> Vec<T>;
758}
759
760#[allow(dead_code)]
761impl<S, T> AsyncCollect<T> for S
762where
763    S: futures::Stream<Item = T> + Unpin,
764{
765    async fn collect(mut self) -> Vec<T> {
766        let mut items = Vec::new();
767        while let Some(item) = self.next().await {
768            items.push(item);
769        }
770        items
771    }
772}