Skip to main content

deepseek_rust_cli/agent/
agent.rs

1use std::sync::Arc;
2
3use tokio_util::sync::CancellationToken;
4use uuid::Uuid;
5
6use crate::{
7    agent::{history::load_history, types::UndoAction},
8    api::{
9        client::DeepSeekClient,
10        types::{Message, TokenUsage},
11    },
12    config::Config,
13};
14
15pub struct DeepSeekAgent {
16    pub client: Arc<DeepSeekClient>,
17    /// Tool result cache to avoid redundant operations
18    pub tool_cache: crate::agent::executor::ToolCache,
19    pub model: String,
20    pub session_id: String,
21    pub messages: Vec<Message>,
22    pub config: Config,
23    pub token_usage: TokenUsage,
24    pub undo_stack: Vec<UndoAction>,
25    pub auto_approve: bool,
26    /// Shared cancel token — can be cancelled from outside the agent mutex
27    pub cancel_token: Arc<std::sync::Mutex<CancellationToken>>,
28    pub run_id: Arc<std::sync::atomic::AtomicUsize>,
29    pub cwd: std::path::PathBuf,
30    /// Cached context detection for dynamic tool filtering
31    pub is_git_repo: bool,
32    pub has_github_token: bool,
33}
34
35impl DeepSeekAgent {
36    pub fn new(api_key: String, config: Config, session_id: Option<String>) -> Self {
37        let sid = session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
38        let mut messages = load_history(&sid);
39
40        if messages.is_empty() {
41            let mut base_prompt = config.system_prompt.clone();
42            if config.concise_reasoning {
43                base_prompt.push_str(
44                    "\nKeep your internal reasoning/thinking process very short and concise.",
45                );
46            }
47            let full_sys = format!(
48                "{}\n{}",
49                base_prompt,
50                crate::agent::context::get_project_context()
51            );
52            messages.push(Message {
53                role: "system".to_string(),
54                content: Some(full_sys),
55                reasoning_content: None,
56                tool_calls: None,
57                tool_call_id: None,
58            });
59        }
60
61        Self {
62            client: Arc::new(DeepSeekClient::new(
63                api_key,
64                config.base_url.clone(),
65                config.request_timeout,
66                config.proxy_url.clone(),
67                config.proxy_username.clone(),
68                config.proxy_password.clone(),
69                config.danger_accept_invalid_certs,
70            )),
71            model: config.model.clone(),
72            session_id: sid,
73            messages,
74            config,
75            token_usage: TokenUsage::default(),
76            undo_stack: Vec::new(),
77            auto_approve: false,
78            cancel_token: Arc::new(std::sync::Mutex::new(CancellationToken::new())),
79            run_id: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
80            cwd: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
81            tool_cache: std::collections::HashMap::new(),
82            is_git_repo: std::path::Path::new(".git").exists(),
83            has_github_token: std::env::var("GITHUB_TOKEN").is_ok(),
84        }
85    }
86
87    /// Reset the cancellation token for a new request
88    pub fn reset_cancel(&mut self) {
89        let mut token = self.cancel_token.lock().unwrap_or_else(|e| e.into_inner());
90        *token = CancellationToken::new();
91    }
92
93    /// Abort the current streaming request
94    pub fn abort(&self) {
95        if let Ok(token) = self.cancel_token.lock() {
96            token.cancel();
97        } else {
98            tracing::warn!("Cancel token mutex poisoned during abort");
99        }
100    }
101
102    /// Check if cancelled (lock-free clone for use in hot loops)
103    pub fn is_cancelled(&self) -> bool {
104        self.cancel_token
105            .lock()
106            .unwrap_or_else(|e| e.into_inner())
107            .is_cancelled()
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    /// The agent keeps the exact `Config` it was constructed with.
    #[test]
    fn test_agent_new_context_chars() {
        let config = Config {
            max_context_chars: 500,
            max_tool_output_chars: 300,
            ..Default::default()
        };
        let agent = DeepSeekAgent::new("dummy_key".to_string(), config, None);
        assert_eq!(agent.config.max_context_chars, 500);
        assert_eq!(agent.config.max_tool_output_chars, 300);
    }

    /// A fresh session (random id, so no saved history) starts with exactly
    /// one system message whose content begins with the configured prompt.
    #[test]
    fn test_new_session_is_seeded_with_system_message() {
        let agent = DeepSeekAgent::new("dummy_key".to_string(), Config::default(), None);
        assert_eq!(agent.messages.len(), 1);
        assert_eq!(agent.messages[0].role, "system");
        let content = agent.messages[0].content.as_deref().unwrap_or("");
        assert!(content.starts_with(&agent.config.system_prompt));
    }

    /// `abort` cancels the current token and `reset_cancel` installs a fresh one.
    #[test]
    fn test_abort_then_reset_cancel() {
        let mut agent = DeepSeekAgent::new("dummy_key".to_string(), Config::default(), None);
        assert!(!agent.is_cancelled());
        agent.abort();
        assert!(agent.is_cancelled());
        agent.reset_cancel();
        assert!(!agent.is_cancelled());
    }

    /// Cancelling through a shared handle (outside the agent) is observed by the agent.
    #[test]
    fn test_cancel_through_shared_handle() {
        let agent = DeepSeekAgent::new("dummy_key".to_string(), Config::default(), None);
        let shared = Arc::clone(&agent.cancel_token);
        shared.lock().unwrap().cancel();
        assert!(agent.is_cancelled());
    }
}