1use crate::agent::ToolUse;
6use crate::provider::{Message, Usage};
7use crate::tool::ToolRegistry;
8use anyhow::Result;
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::path::PathBuf;
12use std::sync::Arc;
13use tokio::fs;
14use uuid::Uuid;
15
/// A persisted conversation between the user and an agent.
///
/// Sessions are serialized as pretty-printed JSON by [`Session::save`] and read back by
/// [`Session::load`] / [`Session::last`]. The system prompt is not stored in `messages`;
/// it is rebuilt and prepended for every completion request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Unique identifier; `Session::new` sets it to a random UUID v4 string. Also used as
    /// the on-disk file stem (`<id>.json`).
    pub id: String,
    /// Human-readable title, derived from the first user message when not set explicitly.
    pub title: Option<String>,
    /// When the session was created (UTC).
    pub created_at: DateTime<Utc>,
    /// Last modification time (UTC); bumped by `add_message`, `set_title`, `clear_title`
    /// and `on_context_change`.
    pub updated_at: DateTime<Utc>,
    /// Conversation history: user, assistant and tool-result messages, in order.
    pub messages: Vec<Message>,
    /// Recorded tool uses. NOTE(review): not written by any code in this file — confirm
    /// who populates it.
    pub tool_uses: Vec<ToolUse>,
    /// Token usage; starts at `Usage::default()`. NOTE(review): not updated by the prompt
    /// loop in this file, which reports usage to telemetry instead.
    pub usage: Usage,
    /// Name of the agent driving the session (`"build"` for new sessions).
    pub agent: String,
    /// Auxiliary settings such as working directory and model.
    pub metadata: SessionMetadata,
}
29
/// Per-session settings stored alongside the conversation.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionMetadata {
    /// Working directory the session belongs to; `Session::new` records the process's
    /// current directory, and the prompt loop uses it when building the system prompt.
    pub directory: Option<PathBuf>,
    /// Requested model, parsed with `parse_model_string`. It may name a model, a
    /// provider, or presumably a `provider/model` pair — confirm against that parser.
    pub model: Option<String>,
    /// Whether the session has been shared. NOTE(review): not read or written in this file.
    pub shared: bool,
    /// Share link, if any. NOTE(review): not read or written in this file.
    pub share_url: Option<String>,
}
37
38impl Session {
39 fn default_model_for_provider(provider: &str) -> String {
40 match provider {
41 "moonshotai" => "kimi-k2.5".to_string(),
42 "anthropic" => "claude-sonnet-4-20250514".to_string(),
43 "openai" => "gpt-4o".to_string(),
44 "google" => "gemini-2.5-pro".to_string(),
45 "zhipuai" => "glm-4.7".to_string(),
46 "openrouter" => "zhipuai/glm-4.7".to_string(),
47 "novita" => "qwen/qwen3-coder-next".to_string(),
48 "github-copilot" | "github-copilot-enterprise" => "gpt-5-mini".to_string(),
49 _ => "glm-4.7".to_string(),
50 }
51 }
52
53 pub async fn new() -> Result<Self> {
55 let id = Uuid::new_v4().to_string();
56 let now = Utc::now();
57
58 Ok(Self {
59 id,
60 title: None,
61 created_at: now,
62 updated_at: now,
63 messages: Vec::new(),
64 tool_uses: Vec::new(),
65 usage: Usage::default(),
66 agent: "build".to_string(),
67 metadata: SessionMetadata {
68 directory: Some(std::env::current_dir()?),
69 ..Default::default()
70 },
71 })
72 }
73
74 pub async fn load(id: &str) -> Result<Self> {
76 let path = Self::session_path(id)?;
77 let content = fs::read_to_string(&path).await?;
78 let session: Session = serde_json::from_str(&content)?;
79 Ok(session)
80 }
81
82 pub async fn last() -> Result<Self> {
84 let sessions_dir = Self::sessions_dir()?;
85
86 if !sessions_dir.exists() {
87 anyhow::bail!("No sessions found");
88 }
89
90 let mut entries: Vec<tokio::fs::DirEntry> = Vec::new();
91 let mut read_dir = fs::read_dir(&sessions_dir).await?;
92 while let Some(entry) = read_dir.next_entry().await? {
93 entries.push(entry);
94 }
95
96 if entries.is_empty() {
97 anyhow::bail!("No sessions found");
98 }
99
100 entries.sort_by_key(|e| {
103 std::cmp::Reverse(
104 std::fs::metadata(e.path())
105 .ok()
106 .and_then(|m| m.modified().ok())
107 .unwrap_or(std::time::SystemTime::UNIX_EPOCH),
108 )
109 });
110
111 if let Some(entry) = entries.first() {
112 let content: String = fs::read_to_string(entry.path()).await?;
113 let session: Session = serde_json::from_str(&content)?;
114 return Ok(session);
115 }
116
117 anyhow::bail!("No sessions found")
118 }
119
120 pub async fn save(&self) -> Result<()> {
122 let path = Self::session_path(&self.id)?;
123
124 if let Some(parent) = path.parent() {
125 fs::create_dir_all(parent).await?;
126 }
127
128 let content = serde_json::to_string_pretty(self)?;
129 fs::write(&path, content).await?;
130
131 Ok(())
132 }
133
134 pub fn add_message(&mut self, message: Message) {
136 self.messages.push(message);
137 self.updated_at = Utc::now();
138 }
139
140 pub async fn prompt(&mut self, message: &str) -> Result<SessionResult> {
142 use crate::provider::{
143 CompletionRequest, ContentPart, ProviderRegistry, Role, parse_model_string,
144 };
145
146 let registry = ProviderRegistry::from_vault().await?;
148
149 let providers = registry.list();
150 if providers.is_empty() {
151 anyhow::bail!(
152 "No providers available. Configure API keys in HashiCorp Vault (for Copilot use `codetether auth copilot`)."
153 );
154 }
155
156 tracing::info!("Available providers: {:?}", providers);
157
158 let (provider_name, model_id) = if let Some(ref model_str) = self.metadata.model {
160 let (prov, model) = parse_model_string(model_str);
161 if prov.is_some() {
162 (prov.map(|s| s.to_string()), model.to_string())
164 } else if providers.contains(&model) {
165 (Some(model.to_string()), String::new())
167 } else {
168 (None, model.to_string())
170 }
171 } else {
172 (None, String::new())
173 };
174
175 let selected_provider = provider_name
177 .as_deref()
178 .filter(|p| providers.contains(p))
179 .or_else(|| {
180 if providers.contains(&"zhipuai") {
181 Some("zhipuai")
182 } else {
183 providers.first().copied()
184 }
185 })
186 .ok_or_else(|| anyhow::anyhow!("No providers available"))?;
187
188 let provider = registry
189 .get(selected_provider)
190 .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
191
192 self.add_message(Message {
194 role: Role::User,
195 content: vec![ContentPart::Text {
196 text: message.to_string(),
197 }],
198 });
199
200 if self.title.is_none() {
202 self.generate_title().await?;
203 }
204
205 let model = if !model_id.is_empty() {
207 model_id
208 } else {
209 Self::default_model_for_provider(selected_provider)
210 };
211
212 let tool_registry = ToolRegistry::with_provider_arc(Arc::clone(&provider), model.clone());
214 let tool_definitions = tool_registry.definitions();
215
216 let temperature = if model.starts_with("kimi-k2") {
218 Some(1.0)
219 } else {
220 Some(0.7)
221 };
222
223 tracing::info!("Using model: {} via provider: {}", model, selected_provider);
224 tracing::info!("Available tools: {}", tool_definitions.len());
225
226 let cwd = self
228 .metadata
229 .directory
230 .clone()
231 .unwrap_or_else(|| std::env::current_dir().unwrap_or_default());
232 let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
233
234 let max_steps = 50;
236 let mut final_output = String::new();
237
238 for step in 1..=max_steps {
239 tracing::info!(step = step, "Agent step starting");
240
241 let mut messages = vec![Message {
243 role: Role::System,
244 content: vec![ContentPart::Text {
245 text: system_prompt.clone(),
246 }],
247 }];
248 messages.extend(self.messages.clone());
249
250 let request = CompletionRequest {
252 messages,
253 tools: tool_definitions.clone(),
254 model: model.clone(),
255 temperature,
256 top_p: None,
257 max_tokens: Some(8192),
258 stop: Vec::new(),
259 };
260
261 let response = provider.complete(request).await?;
263
264 crate::telemetry::TOKEN_USAGE.record_model_usage(
266 &model,
267 response.usage.prompt_tokens as u64,
268 response.usage.completion_tokens as u64,
269 );
270
271 let tool_calls: Vec<(String, String, serde_json::Value)> = response
273 .message
274 .content
275 .iter()
276 .filter_map(|part| {
277 if let ContentPart::ToolCall {
278 id,
279 name,
280 arguments,
281 } = part
282 {
283 let args: serde_json::Value =
285 serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
286 Some((id.clone(), name.clone(), args))
287 } else {
288 None
289 }
290 })
291 .collect();
292
293 for part in &response.message.content {
295 if let ContentPart::Text { text } = part {
296 if !text.is_empty() {
297 final_output.push_str(text);
298 final_output.push('\n');
299 }
300 }
301 }
302
303 if tool_calls.is_empty() {
305 self.add_message(response.message.clone());
306 break;
307 }
308
309 self.add_message(response.message.clone());
311
312 tracing::info!(
313 step = step,
314 num_tools = tool_calls.len(),
315 "Executing tool calls"
316 );
317
318 for (tool_id, tool_name, tool_input) in tool_calls {
320 tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
321
322 let content = if let Some(tool) = tool_registry.get(&tool_name) {
324 match tool.execute(tool_input.clone()).await {
325 Ok(result) => {
326 tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
327 result.output
328 }
329 Err(e) => {
330 tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
331 format!("Error: {}", e)
332 }
333 }
334 } else {
335 tracing::warn!(tool = %tool_name, "Tool not found");
336 format!("Error: Unknown tool '{}'", tool_name)
337 };
338
339 self.add_message(Message {
341 role: Role::Tool,
342 content: vec![ContentPart::ToolResult {
343 tool_call_id: tool_id,
344 content,
345 }],
346 });
347 }
348 }
349
350 self.save().await?;
352
353 Ok(SessionResult {
354 text: final_output.trim().to_string(),
355 session_id: self.id.clone(),
356 })
357 }
358
359 pub async fn prompt_with_events(
362 &mut self,
363 message: &str,
364 event_tx: tokio::sync::mpsc::Sender<SessionEvent>,
365 ) -> Result<SessionResult> {
366 use crate::provider::{
367 CompletionRequest, ContentPart, ProviderRegistry, Role, parse_model_string,
368 };
369
370 let _ = event_tx.send(SessionEvent::Thinking).await;
371
372 let registry = ProviderRegistry::from_vault().await?;
374 let providers = registry.list();
375 if providers.is_empty() {
376 anyhow::bail!(
377 "No providers available. Configure API keys in HashiCorp Vault (for Copilot use `codetether auth copilot`)."
378 );
379 }
380 tracing::info!("Available providers: {:?}", providers);
381
382 let (provider_name, model_id) = if let Some(ref model_str) = self.metadata.model {
384 let (prov, model) = parse_model_string(model_str);
385 if prov.is_some() {
386 (prov.map(|s| s.to_string()), model.to_string())
387 } else if providers.contains(&model) {
388 (Some(model.to_string()), String::new())
389 } else {
390 (None, model.to_string())
391 }
392 } else {
393 (None, String::new())
394 };
395
396 let selected_provider = provider_name
398 .as_deref()
399 .filter(|p| providers.contains(p))
400 .or_else(|| {
401 if providers.contains(&"zhipuai") {
402 Some("zhipuai")
403 } else {
404 providers.first().copied()
405 }
406 })
407 .ok_or_else(|| anyhow::anyhow!("No providers available"))?;
408
409 let provider = registry
410 .get(selected_provider)
411 .ok_or_else(|| anyhow::anyhow!("Provider {} not found", selected_provider))?;
412
413 self.add_message(Message {
415 role: Role::User,
416 content: vec![ContentPart::Text {
417 text: message.to_string(),
418 }],
419 });
420
421 if self.title.is_none() {
423 self.generate_title().await?;
424 }
425
426 let model = if !model_id.is_empty() {
428 model_id
429 } else {
430 Self::default_model_for_provider(selected_provider)
431 };
432
433 let tool_registry = ToolRegistry::with_provider_arc(Arc::clone(&provider), model.clone());
435 let tool_definitions = tool_registry.definitions();
436
437 let temperature = if model.starts_with("kimi-k2") {
438 Some(1.0)
439 } else {
440 Some(0.7)
441 };
442
443 tracing::info!("Using model: {} via provider: {}", model, selected_provider);
444 tracing::info!("Available tools: {}", tool_definitions.len());
445
446 let cwd = std::env::var("PWD")
448 .map(std::path::PathBuf::from)
449 .unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
450 let system_prompt = crate::agent::builtin::build_system_prompt(&cwd);
451
452 let mut final_output = String::new();
453 let max_steps = 50;
454
455 for step in 1..=max_steps {
456 tracing::info!(step = step, "Agent step starting");
457 let _ = event_tx.send(SessionEvent::Thinking).await;
458
459 let mut messages = vec![Message {
461 role: Role::System,
462 content: vec![ContentPart::Text {
463 text: system_prompt.clone(),
464 }],
465 }];
466 messages.extend(self.messages.clone());
467
468 let request = CompletionRequest {
469 messages,
470 tools: tool_definitions.clone(),
471 model: model.clone(),
472 temperature,
473 top_p: None,
474 max_tokens: Some(8192),
475 stop: Vec::new(),
476 };
477
478 let response = provider.complete(request).await?;
479
480 crate::telemetry::TOKEN_USAGE.record_model_usage(
481 &model,
482 response.usage.prompt_tokens as u64,
483 response.usage.completion_tokens as u64,
484 );
485
486 let tool_calls: Vec<(String, String, serde_json::Value)> = response
488 .message
489 .content
490 .iter()
491 .filter_map(|part| {
492 if let ContentPart::ToolCall {
493 id,
494 name,
495 arguments,
496 } = part
497 {
498 let args: serde_json::Value =
499 serde_json::from_str(arguments).unwrap_or(serde_json::json!({}));
500 Some((id.clone(), name.clone(), args))
501 } else {
502 None
503 }
504 })
505 .collect();
506
507 for part in &response.message.content {
509 if let ContentPart::Text { text } = part {
510 if !text.is_empty() {
511 final_output.push_str(text);
512 final_output.push('\n');
513 let _ = event_tx.send(SessionEvent::TextChunk(text.clone())).await;
514 }
515 }
516 }
517
518 if tool_calls.is_empty() {
519 self.add_message(response.message.clone());
520 break;
521 }
522
523 self.add_message(response.message.clone());
524
525 tracing::info!(
526 step = step,
527 num_tools = tool_calls.len(),
528 "Executing tool calls"
529 );
530
531 for (tool_id, tool_name, tool_input) in tool_calls {
533 let args_str = serde_json::to_string(&tool_input).unwrap_or_default();
534 let _ = event_tx
535 .send(SessionEvent::ToolCallStart {
536 name: tool_name.clone(),
537 arguments: args_str,
538 })
539 .await;
540
541 tracing::info!(tool = %tool_name, tool_id = %tool_id, "Executing tool");
542
543 let (content, success) = if let Some(tool) = tool_registry.get(&tool_name) {
544 match tool.execute(tool_input.clone()).await {
545 Ok(result) => {
546 tracing::info!(tool = %tool_name, success = result.success, "Tool execution completed");
547 (result.output, result.success)
548 }
549 Err(e) => {
550 tracing::warn!(tool = %tool_name, error = %e, "Tool execution failed");
551 (format!("Error: {}", e), false)
552 }
553 }
554 } else {
555 tracing::warn!(tool = %tool_name, "Tool not found");
556 (format!("Error: Unknown tool '{}'", tool_name), false)
557 };
558
559 let _ = event_tx
560 .send(SessionEvent::ToolCallComplete {
561 name: tool_name.clone(),
562 output: content.clone(),
563 success,
564 })
565 .await;
566
567 self.add_message(Message {
568 role: Role::Tool,
569 content: vec![ContentPart::ToolResult {
570 tool_call_id: tool_id,
571 content,
572 }],
573 });
574 }
575 }
576
577 self.save().await?;
578
579 let _ = event_tx
580 .send(SessionEvent::TextComplete(final_output.trim().to_string()))
581 .await;
582 let _ = event_tx.send(SessionEvent::Done).await;
583
584 Ok(SessionResult {
585 text: final_output.trim().to_string(),
586 session_id: self.id.clone(),
587 })
588 }
589
590 pub async fn generate_title(&mut self) -> Result<()> {
593 if self.title.is_some() {
594 return Ok(());
595 }
596
597 let first_message = self
599 .messages
600 .iter()
601 .find(|m| m.role == crate::provider::Role::User);
602
603 if let Some(msg) = first_message {
604 let text: String = msg
605 .content
606 .iter()
607 .filter_map(|p| match p {
608 crate::provider::ContentPart::Text { text } => Some(text.clone()),
609 _ => None,
610 })
611 .collect::<Vec<_>>()
612 .join(" ");
613
614 self.title = Some(if text.len() > 50 {
616 format!("{}...", &text[..47])
617 } else {
618 text
619 });
620 }
621
622 Ok(())
623 }
624
625 pub async fn regenerate_title(&mut self) -> Result<()> {
628 let first_message = self
630 .messages
631 .iter()
632 .find(|m| m.role == crate::provider::Role::User);
633
634 if let Some(msg) = first_message {
635 let text: String = msg
636 .content
637 .iter()
638 .filter_map(|p| match p {
639 crate::provider::ContentPart::Text { text } => Some(text.clone()),
640 _ => None,
641 })
642 .collect::<Vec<_>>()
643 .join(" ");
644
645 self.title = Some(if text.len() > 50 {
647 format!("{}...", &text[..47])
648 } else {
649 text
650 });
651 }
652
653 Ok(())
654 }
655
656 pub fn set_title(&mut self, title: impl Into<String>) {
658 self.title = Some(title.into());
659 self.updated_at = Utc::now();
660 }
661
662 pub fn clear_title(&mut self) {
664 self.title = None;
665 self.updated_at = Utc::now();
666 }
667
668 pub async fn on_context_change(&mut self, regenerate_title: bool) -> Result<()> {
671 self.updated_at = Utc::now();
672
673 if regenerate_title {
674 self.regenerate_title().await?;
675 }
676
677 Ok(())
678 }
679
680 fn sessions_dir() -> Result<PathBuf> {
682 crate::config::Config::data_dir()
683 .map(|d| d.join("sessions"))
684 .ok_or_else(|| anyhow::anyhow!("Could not determine data directory"))
685 }
686
687 fn session_path(id: &str) -> Result<PathBuf> {
689 Ok(Self::sessions_dir()?.join(format!("{}.json", id)))
690 }
691}
692
/// Outcome of one prompt turn.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionResult {
    /// Non-empty assistant text produced during the turn, newline-joined and trimmed.
    pub text: String,
    /// Id of the session the turn was run in.
    pub session_id: String,
}
699
/// Progress notifications sent by `Session::prompt_with_events`.
#[derive(Debug, Clone)]
pub enum SessionEvent {
    /// The agent is starting work (sent at the beginning of the turn and of each step).
    Thinking,
    /// A tool is about to run; `arguments` is the tool input serialized as JSON.
    ToolCallStart { name: String, arguments: String },
    /// A tool finished; `output` is the text returned to the model.
    ToolCallComplete {
        name: String,
        output: String,
        success: bool,
    },
    /// One non-empty text part of an assistant response.
    TextChunk(String),
    /// The whole turn's assistant text, trimmed.
    TextComplete(String),
    /// The turn finished and the session was saved.
    Done,
    /// Error description as a string; check the emitters for when this is sent.
    Error(String),
}
722
723pub async fn list_sessions() -> Result<Vec<SessionSummary>> {
725 let sessions_dir = crate::config::Config::data_dir()
726 .map(|d| d.join("sessions"))
727 .ok_or_else(|| anyhow::anyhow!("Could not determine data directory"))?;
728
729 if !sessions_dir.exists() {
730 return Ok(Vec::new());
731 }
732
733 let mut summaries = Vec::new();
734 let mut entries = fs::read_dir(&sessions_dir).await?;
735
736 while let Some(entry) = entries.next_entry().await? {
737 let path = entry.path();
738 if path.extension().map(|e| e == "json").unwrap_or(false) {
739 if let Ok(content) = fs::read_to_string(&path).await {
740 if let Ok(session) = serde_json::from_str::<Session>(&content) {
741 summaries.push(SessionSummary {
742 id: session.id,
743 title: session.title,
744 created_at: session.created_at,
745 updated_at: session.updated_at,
746 message_count: session.messages.len(),
747 agent: session.agent,
748 });
749 }
750 }
751 }
752 }
753
754 summaries.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
755 Ok(summaries)
756}
757
/// Lightweight view of a stored session, as returned by `list_sessions`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionSummary {
    /// Session id (also the JSON file stem).
    pub id: String,
    /// Session title, if one has been set or generated.
    pub title: Option<String>,
    /// When the session was created (UTC).
    pub created_at: DateTime<Utc>,
    /// Last modification time (UTC); `list_sessions` sorts on this, newest first.
    pub updated_at: DateTime<Utc>,
    /// Number of stored messages (user, assistant and tool results).
    pub message_count: usize,
    /// Name of the agent that drives the session.
    pub agent: String,
}
768
769#[allow(dead_code)]
771use futures::StreamExt;
772
773#[allow(dead_code)]
774trait AsyncCollect<T> {
775 async fn collect(self) -> Vec<T>;
776}
777
778#[allow(dead_code)]
779impl<S, T> AsyncCollect<T> for S
780where
781 S: futures::Stream<Item = T> + Unpin,
782{
783 async fn collect(mut self) -> Vec<T> {
784 let mut items = Vec::new();
785 while let Some(item) = self.next().await {
786 items.push(item);
787 }
788 items
789 }
790}