1use crate::{
2 checkpoint_store::CheckpointStore, event_log::EventLog, idempotency::IdempotencyStore,
3 session_manager::SessionManager,
4};
5use stakpak_agent_core::{ProposedToolCall, ToolApprovalPolicy};
6use stakpak_api::SessionStorage;
7use stakpak_mcp_client::McpClient;
8use std::{collections::HashMap, sync::Arc, time::Instant};
9use tokio::sync::{RwLock, broadcast};
10use uuid::Uuid;
11
12#[derive(Debug, Clone)]
13pub struct PendingToolApprovals {
14 pub run_id: Uuid,
15 pub tool_calls: Vec<ProposedToolCall>,
16}
17
18#[derive(Clone)]
19pub struct AppState {
20 pub run_manager: SessionManager,
21 pub session_store: Arc<dyn SessionStorage>,
23 pub events: Arc<EventLog>,
24 pub idempotency: Arc<IdempotencyStore>,
25 pub inference: Arc<stakai::Inference>,
26 pub checkpoint_store: Arc<CheckpointStore>,
28 pub models: Arc<Vec<stakai::Model>>,
29 pub default_model: Option<stakai::Model>,
30 pub tool_approval_policy: ToolApprovalPolicy,
31 pub started_at: Instant,
32 pub mcp_client: Option<Arc<McpClient>>,
33 pub mcp_tools: Arc<RwLock<Vec<stakai::Tool>>>,
34 pub mcp_server_shutdown_tx: Option<broadcast::Sender<()>>,
35 pub mcp_proxy_shutdown_tx: Option<broadcast::Sender<()>>,
36 pending_tools: Arc<RwLock<HashMap<Uuid, PendingToolApprovals>>>,
37}
38
39impl AppState {
40 #[allow(clippy::too_many_arguments)]
41 pub fn new(
42 session_store: Arc<dyn SessionStorage>,
43 events: Arc<EventLog>,
44 idempotency: Arc<IdempotencyStore>,
45 inference: Arc<stakai::Inference>,
46 models: Vec<stakai::Model>,
47 default_model: Option<stakai::Model>,
48 tool_approval_policy: ToolApprovalPolicy,
49 ) -> Self {
50 Self {
51 run_manager: SessionManager::new(),
52 session_store,
53 events,
54 idempotency,
55 inference,
56 checkpoint_store: Arc::new(CheckpointStore::default_local()),
57 models: Arc::new(models),
58 default_model,
59 tool_approval_policy,
60 started_at: Instant::now(),
61 mcp_client: None,
62 mcp_tools: Arc::new(RwLock::new(Vec::new())),
63 mcp_server_shutdown_tx: None,
64 mcp_proxy_shutdown_tx: None,
65 pending_tools: Arc::new(RwLock::new(HashMap::new())),
66 }
67 }
68
69 pub fn with_mcp(
70 mut self,
71 mcp_client: Arc<McpClient>,
72 mcp_tools: Vec<stakai::Tool>,
73 mcp_server_shutdown_tx: Option<broadcast::Sender<()>>,
74 mcp_proxy_shutdown_tx: Option<broadcast::Sender<()>>,
75 ) -> Self {
76 self.mcp_client = Some(mcp_client);
77 self.mcp_tools = Arc::new(RwLock::new(mcp_tools));
78 self.mcp_server_shutdown_tx = mcp_server_shutdown_tx;
79 self.mcp_proxy_shutdown_tx = mcp_proxy_shutdown_tx;
80 self
81 }
82
83 pub fn with_checkpoint_store(mut self, checkpoint_store: Arc<CheckpointStore>) -> Self {
84 self.checkpoint_store = checkpoint_store;
85 self
86 }
87
88 pub async fn current_mcp_tools(&self) -> Vec<stakai::Tool> {
89 self.mcp_tools.read().await.clone()
90 }
91
92 pub async fn refresh_mcp_tools(&self) -> Result<usize, String> {
93 let Some(mcp_client) = self.mcp_client.as_ref() else {
94 return Ok(self.mcp_tools.read().await.len());
95 };
96
97 let raw_tools = stakpak_mcp_client::get_tools(mcp_client)
98 .await
99 .map_err(|error| format!("Failed to refresh MCP tools: {error}"))?;
100
101 let converted = raw_tools
102 .into_iter()
103 .map(|tool| stakai::Tool {
104 tool_type: "function".to_string(),
105 function: stakai::ToolFunction {
106 name: tool.name.as_ref().to_string(),
107 description: tool
108 .description
109 .as_ref()
110 .map(std::string::ToString::to_string)
111 .unwrap_or_default(),
112 parameters: serde_json::Value::Object((*tool.input_schema).clone()),
113 },
114 provider_options: None,
115 })
116 .collect::<Vec<_>>();
117
118 let mut guard = self.mcp_tools.write().await;
119 *guard = converted;
120 Ok(guard.len())
121 }
122
123 pub fn uptime_seconds(&self) -> u64 {
124 self.started_at.elapsed().as_secs()
125 }
126
127 pub fn resolve_model(&self, requested: Option<&str>) -> Option<stakai::Model> {
128 match requested {
129 Some(requested_model) => self.find_model(requested_model),
130 None => self
131 .default_model
132 .clone()
133 .or_else(|| self.models.first().cloned()),
134 }
135 }
136
137 pub async fn set_pending_tools(
138 &self,
139 session_id: Uuid,
140 run_id: Uuid,
141 tool_calls: Vec<ProposedToolCall>,
142 ) {
143 let mut guard = self.pending_tools.write().await;
144 guard.insert(session_id, PendingToolApprovals { run_id, tool_calls });
145 }
146
147 pub async fn clear_pending_tools(&self, session_id: Uuid, run_id: Uuid) {
148 let mut guard = self.pending_tools.write().await;
149 if guard
150 .get(&session_id)
151 .is_some_and(|pending| pending.run_id == run_id)
152 {
153 guard.remove(&session_id);
154 }
155 }
156
157 pub async fn pending_tools(&self, session_id: Uuid) -> Option<PendingToolApprovals> {
158 let guard = self.pending_tools.read().await;
159 guard.get(&session_id).cloned()
160 }
161
162 fn find_model(&self, requested: &str) -> Option<stakai::Model> {
163 if let Some((provider, id)) = requested.split_once('/') {
164 return self
165 .models
166 .iter()
167 .find(|model| model.provider == provider && model.id == id)
168 .cloned()
169 .or_else(|| Some(stakai::Model::custom(id, provider)));
170 }
171
172 self.models
173 .iter()
174 .find(|model| model.id == requested)
175 .cloned()
176 .or_else(|| {
177 self.default_model.as_ref().map(|default_model| {
178 stakai::Model::custom(requested.to_string(), default_model.provider.clone())
179 })
180 })
181 .or_else(|| {
182 self.models.first().map(|model| {
183 stakai::Model::custom(requested.to_string(), model.provider.clone())
184 })
185 })
186 .or_else(|| Some(stakai::Model::custom(requested.to_string(), "openai")))
187 }
188}