1use crate::{
2 checkpoint_store::CheckpointStore, event_log::EventLog, idempotency::IdempotencyStore,
3 sandbox::SandboxConfig, session_manager::SessionManager,
4};
5use stakpak_agent_core::{ProposedToolCall, ToolApprovalPolicy};
6use stakpak_api::SessionStorage;
7use stakpak_mcp_client::McpClient;
8use std::{collections::HashMap, sync::Arc, time::Instant};
9use tokio::sync::{RwLock, broadcast};
10use uuid::Uuid;
11
12#[derive(Debug, Clone)]
13pub struct PendingToolApprovals {
14 pub run_id: Uuid,
15 pub tool_calls: Vec<ProposedToolCall>,
16}
17
18#[derive(Clone)]
19pub struct AppState {
20 pub run_manager: SessionManager,
21 pub session_store: Arc<dyn SessionStorage>,
23 pub events: Arc<EventLog>,
24 pub idempotency: Arc<IdempotencyStore>,
25 pub inference: Arc<stakai::Inference>,
26 pub checkpoint_store: Arc<CheckpointStore>,
28 pub models: Arc<Vec<stakai::Model>>,
29 pub default_model: Option<stakai::Model>,
30 pub tool_approval_policy: ToolApprovalPolicy,
31 pub started_at: Instant,
32 pub mcp_client: Option<Arc<McpClient>>,
33 pub mcp_tools: Arc<RwLock<Vec<stakai::Tool>>>,
34 pub mcp_server_shutdown_tx: Option<broadcast::Sender<()>>,
35 pub mcp_proxy_shutdown_tx: Option<broadcast::Sender<()>>,
36 pub sandbox_config: Option<SandboxConfig>,
37 pending_tools: Arc<RwLock<HashMap<Uuid, PendingToolApprovals>>>,
38}
39
40impl AppState {
41 #[allow(clippy::too_many_arguments)]
42 pub fn new(
43 session_store: Arc<dyn SessionStorage>,
44 events: Arc<EventLog>,
45 idempotency: Arc<IdempotencyStore>,
46 inference: Arc<stakai::Inference>,
47 models: Vec<stakai::Model>,
48 default_model: Option<stakai::Model>,
49 tool_approval_policy: ToolApprovalPolicy,
50 ) -> Self {
51 Self {
52 run_manager: SessionManager::new(),
53 session_store,
54 events,
55 idempotency,
56 inference,
57 checkpoint_store: Arc::new(CheckpointStore::default_local()),
58 models: Arc::new(models),
59 default_model,
60 tool_approval_policy,
61 started_at: Instant::now(),
62 mcp_client: None,
63 mcp_tools: Arc::new(RwLock::new(Vec::new())),
64 mcp_server_shutdown_tx: None,
65 mcp_proxy_shutdown_tx: None,
66 sandbox_config: None,
67 pending_tools: Arc::new(RwLock::new(HashMap::new())),
68 }
69 }
70
71 pub fn with_mcp(
72 mut self,
73 mcp_client: Arc<McpClient>,
74 mcp_tools: Vec<stakai::Tool>,
75 mcp_server_shutdown_tx: Option<broadcast::Sender<()>>,
76 mcp_proxy_shutdown_tx: Option<broadcast::Sender<()>>,
77 ) -> Self {
78 self.mcp_client = Some(mcp_client);
79 self.mcp_tools = Arc::new(RwLock::new(mcp_tools));
80 self.mcp_server_shutdown_tx = mcp_server_shutdown_tx;
81 self.mcp_proxy_shutdown_tx = mcp_proxy_shutdown_tx;
82 self
83 }
84
85 pub fn with_sandbox(mut self, sandbox_config: SandboxConfig) -> Self {
86 self.sandbox_config = Some(sandbox_config);
87 self
88 }
89
90 pub fn with_checkpoint_store(mut self, checkpoint_store: Arc<CheckpointStore>) -> Self {
91 self.checkpoint_store = checkpoint_store;
92 self
93 }
94
95 pub async fn current_mcp_tools(&self) -> Vec<stakai::Tool> {
96 self.mcp_tools.read().await.clone()
97 }
98
99 pub async fn refresh_mcp_tools(&self) -> Result<usize, String> {
100 let Some(mcp_client) = self.mcp_client.as_ref() else {
101 return Ok(self.mcp_tools.read().await.len());
102 };
103
104 let raw_tools = stakpak_mcp_client::get_tools(mcp_client)
105 .await
106 .map_err(|error| format!("Failed to refresh MCP tools: {error}"))?;
107
108 let converted = raw_tools
109 .into_iter()
110 .map(|tool| stakai::Tool {
111 tool_type: "function".to_string(),
112 function: stakai::ToolFunction {
113 name: tool.name.as_ref().to_string(),
114 description: tool
115 .description
116 .as_ref()
117 .map(std::string::ToString::to_string)
118 .unwrap_or_default(),
119 parameters: serde_json::Value::Object((*tool.input_schema).clone()),
120 },
121 provider_options: None,
122 })
123 .collect::<Vec<_>>();
124
125 let mut guard = self.mcp_tools.write().await;
126 *guard = converted;
127 Ok(guard.len())
128 }
129
130 pub fn uptime_seconds(&self) -> u64 {
131 self.started_at.elapsed().as_secs()
132 }
133
134 pub fn resolve_model(&self, requested: Option<&str>) -> Option<stakai::Model> {
135 match requested {
136 Some(requested_model) => self.find_model(requested_model),
137 None => self
138 .default_model
139 .clone()
140 .or_else(|| self.models.first().cloned()),
141 }
142 }
143
144 pub async fn set_pending_tools(
145 &self,
146 session_id: Uuid,
147 run_id: Uuid,
148 tool_calls: Vec<ProposedToolCall>,
149 ) {
150 let mut guard = self.pending_tools.write().await;
151 guard.insert(session_id, PendingToolApprovals { run_id, tool_calls });
152 }
153
154 pub async fn clear_pending_tools(&self, session_id: Uuid, run_id: Uuid) {
155 let mut guard = self.pending_tools.write().await;
156 if guard
157 .get(&session_id)
158 .is_some_and(|pending| pending.run_id == run_id)
159 {
160 guard.remove(&session_id);
161 }
162 }
163
164 pub async fn pending_tools(&self, session_id: Uuid) -> Option<PendingToolApprovals> {
165 let guard = self.pending_tools.read().await;
166 guard.get(&session_id).cloned()
167 }
168
169 fn find_model(&self, requested: &str) -> Option<stakai::Model> {
170 if let Some((provider, id)) = requested.split_once('/') {
171 return self
172 .models
173 .iter()
174 .find(|model| model.provider == provider && model.id == id)
175 .cloned()
176 .or_else(|| Some(stakai::Model::custom(id, provider)));
177 }
178
179 self.models
180 .iter()
181 .find(|model| model.id == requested)
182 .cloned()
183 .or_else(|| {
184 self.default_model.as_ref().map(|default_model| {
185 stakai::Model::custom(requested.to_string(), default_model.provider.clone())
186 })
187 })
188 .or_else(|| {
189 self.models.first().map(|model| {
190 stakai::Model::custom(requested.to_string(), model.provider.clone())
191 })
192 })
193 .or_else(|| Some(stakai::Model::custom(requested.to_string(), "openai")))
194 }
195}