1use super::{
6 CheckpointState, CreateCheckpointRequest, CreateCheckpointResponse, CreateSessionRequest,
7 CreateSessionResponse, GetCheckpointResponse, GetSessionResponse, ListCheckpointsQuery,
8 ListCheckpointsResponse, ListSessionsQuery, ListSessionsResponse, SessionVisibility,
9 StakpakApiConfig, UpdateSessionRequest, UpdateSessionResponse, models::*,
10};
11use crate::models::{
12 CreateRuleBookInput, CreateRuleBookResponse, GetMyAccountResponse, ListRuleBook,
13 ListRulebooksResponse, RuleBook,
14};
15use reqwest::{Client as ReqwestClient, Response, header};
16use rmcp::model::Content;
17use serde::de::DeserializeOwned;
18use serde_json::{Value, json};
19use stakpak_shared::models::billing::BillingResponse;
20use uuid::Uuid;
21
22#[derive(Clone, Debug)]
24pub struct StakpakApiClient {
25 client: ReqwestClient,
26 base_url: String,
27}
28
29#[derive(Debug, serde::Deserialize)]
31struct ApiError {
32 error: ApiErrorDetail,
33}
34
35#[derive(Debug, serde::Deserialize)]
36struct ApiErrorDetail {
37 key: String,
38 message: String,
39}
40
41impl StakpakApiClient {
42 pub fn new(config: &StakpakApiConfig) -> Result<Self, String> {
44 if config.api_key.is_empty() {
45 return Err("Stakpak API key is required".to_string());
46 }
47
48 let mut headers = header::HeaderMap::new();
49 headers.insert(
50 header::AUTHORIZATION,
51 header::HeaderValue::from_str(&format!("Bearer {}", config.api_key))
52 .map_err(|e| e.to_string())?,
53 );
54 headers.insert(
55 header::USER_AGENT,
56 header::HeaderValue::from_str(&format!("Stakpak/{}", env!("CARGO_PKG_VERSION")))
57 .map_err(|e| e.to_string())?,
58 );
59
60 let client = ReqwestClient::builder()
61 .default_headers(headers)
62 .timeout(std::time::Duration::from_secs(300))
63 .build()
64 .map_err(|e| e.to_string())?;
65
66 Ok(Self {
67 client,
68 base_url: config.api_endpoint.clone(),
69 })
70 }
71
72 pub async fn create_session(
78 &self,
79 req: &CreateSessionRequest,
80 ) -> Result<CreateSessionResponse, String> {
81 let url = format!("{}/v1/sessions", self.base_url);
82 let response = self
83 .client
84 .post(&url)
85 .json(req)
86 .send()
87 .await
88 .map_err(|e| e.to_string())?;
89 self.handle_response(response).await
90 }
91
92 pub async fn create_checkpoint(
94 &self,
95 session_id: Uuid,
96 req: &CreateCheckpointRequest,
97 ) -> Result<CreateCheckpointResponse, String> {
98 let url = format!("{}/v1/sessions/{}/checkpoints", self.base_url, session_id);
99 let response = self
100 .client
101 .post(&url)
102 .json(req)
103 .send()
104 .await
105 .map_err(|e| e.to_string())?;
106 self.handle_response(response).await
107 }
108
109 pub async fn list_sessions(
111 &self,
112 query: &ListSessionsQuery,
113 ) -> Result<ListSessionsResponse, String> {
114 let url = format!("{}/v1/sessions", self.base_url);
115 let response = self
116 .client
117 .get(&url)
118 .query(query)
119 .send()
120 .await
121 .map_err(|e| e.to_string())?;
122 self.handle_response(response).await
123 }
124
125 pub async fn get_session(&self, id: Uuid) -> Result<GetSessionResponse, String> {
127 let url = format!("{}/v1/sessions/{}", self.base_url, id);
128 let response = self
129 .client
130 .get(&url)
131 .send()
132 .await
133 .map_err(|e| e.to_string())?;
134 self.handle_response(response).await
135 }
136
137 pub async fn update_session(
139 &self,
140 id: Uuid,
141 req: &UpdateSessionRequest,
142 ) -> Result<UpdateSessionResponse, String> {
143 let url = format!("{}/v1/sessions/{}", self.base_url, id);
144 let response = self
145 .client
146 .patch(&url)
147 .json(req)
148 .send()
149 .await
150 .map_err(|e| e.to_string())?;
151 self.handle_response(response).await
152 }
153
154 pub async fn delete_session(&self, id: Uuid) -> Result<(), String> {
156 let url = format!("{}/v1/sessions/{}", self.base_url, id);
157 let response = self
158 .client
159 .delete(&url)
160 .send()
161 .await
162 .map_err(|e| e.to_string())?;
163 self.handle_response_no_body(response).await
164 }
165
166 pub async fn list_checkpoints(
168 &self,
169 session_id: Uuid,
170 query: &ListCheckpointsQuery,
171 ) -> Result<ListCheckpointsResponse, String> {
172 let url = format!("{}/v1/sessions/{}/checkpoints", self.base_url, session_id);
173 let response = self
174 .client
175 .get(&url)
176 .query(query)
177 .send()
178 .await
179 .map_err(|e| e.to_string())?;
180 self.handle_response(response).await
181 }
182
183 pub async fn get_checkpoint(&self, id: Uuid) -> Result<GetCheckpointResponse, String> {
185 let url = format!("{}/v1/sessions/checkpoints/{}", self.base_url, id);
186 let response = self
187 .client
188 .get(&url)
189 .send()
190 .await
191 .map_err(|e| e.to_string())?;
192 self.handle_response(response).await
193 }
194
195 pub async fn cancel_request(&self, request_id: &str) -> Result<(), String> {
201 let url = format!("{}/v1/chat/requests/{}/cancel", self.base_url, request_id);
202 let response = self
203 .client
204 .post(&url)
205 .send()
206 .await
207 .map_err(|e| e.to_string())?;
208 self.handle_response_no_body(response).await
209 }
210
211 pub async fn get_account(&self) -> Result<GetMyAccountResponse, String> {
217 let url = format!("{}/v1/account", self.base_url);
218 let response = self
219 .client
220 .get(&url)
221 .send()
222 .await
223 .map_err(|e| e.to_string())?;
224 self.handle_response(response).await
225 }
226
227 pub async fn get_billing(&self, username: &str) -> Result<BillingResponse, String> {
229 let url = format!("{}/v2/{}/billing", self.base_url, username);
230 let response = self
231 .client
232 .get(&url)
233 .send()
234 .await
235 .map_err(|e| e.to_string())?;
236 self.handle_response(response).await
237 }
238
239 pub async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
245 let url = format!("{}/v1/rules", self.base_url);
246 let response = self
247 .client
248 .get(&url)
249 .send()
250 .await
251 .map_err(|e| e.to_string())?;
252
253 let response = self.handle_response_error(response).await?;
254 let value: Value = response.json().await.map_err(|e| e.to_string())?;
255
256 match serde_json::from_value::<ListRulebooksResponse>(value) {
257 Ok(response) => Ok(response.results),
258 Err(e) => Err(format!("Failed to deserialize rulebooks response: {}", e)),
259 }
260 }
261
262 pub async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
264 let url = format!("{}/v1/rules/{}", self.base_url, uri);
265 let response = self
266 .client
267 .get(&url)
268 .send()
269 .await
270 .map_err(|e| e.to_string())?;
271 self.handle_response(response).await
272 }
273
274 pub async fn create_rulebook(
276 &self,
277 input: &CreateRuleBookInput,
278 ) -> Result<CreateRuleBookResponse, String> {
279 let url = format!("{}/v1/rules", self.base_url);
280 let response = self
281 .client
282 .post(&url)
283 .json(input)
284 .send()
285 .await
286 .map_err(|e| e.to_string())?;
287 self.handle_response(response).await
288 }
289
290 pub async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
292 let url = format!("{}/v1/rules/{}", self.base_url, uri);
293 let response = self
294 .client
295 .delete(&url)
296 .send()
297 .await
298 .map_err(|e| e.to_string())?;
299 self.handle_response_no_body(response).await
300 }
301
302 pub async fn search_docs(&self, req: &SearchDocsRequest) -> Result<Vec<Content>, String> {
308 self.call_mcp_tool(&ToolsCallParams {
309 name: "search_docs".to_string(),
310 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
311 })
312 .await
313 }
314
315 pub async fn search_memory(&self, req: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
317 self.call_mcp_tool(&ToolsCallParams {
318 name: "search_memory".to_string(),
319 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
320 })
321 .await
322 }
323
324 pub async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
326 let url = format!(
327 "{}/v1/agents/sessions/checkpoints/{}/extract-memory",
328 self.base_url, checkpoint_id
329 );
330 let response = self
331 .client
332 .post(&url)
333 .send()
334 .await
335 .map_err(|e| e.to_string())?;
336 self.handle_response_no_body(response).await
337 }
338
339 pub async fn slack_read_messages(
341 &self,
342 req: &SlackReadMessagesRequest,
343 ) -> Result<Vec<Content>, String> {
344 self.call_mcp_tool(&ToolsCallParams {
345 name: "slack_read_messages".to_string(),
346 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
347 })
348 .await
349 }
350
351 pub async fn slack_read_replies(
353 &self,
354 req: &SlackReadRepliesRequest,
355 ) -> Result<Vec<Content>, String> {
356 self.call_mcp_tool(&ToolsCallParams {
357 name: "slack_read_replies".to_string(),
358 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
359 })
360 .await
361 }
362
363 pub async fn slack_send_message(
365 &self,
366 req: &SlackSendMessageRequest,
367 ) -> Result<Vec<Content>, String> {
368 self.call_mcp_tool(&ToolsCallParams {
369 name: "slack_send_message".to_string(),
370 arguments: serde_json::to_value(req).map_err(|e| e.to_string())?,
371 })
372 .await
373 }
374
375 async fn call_mcp_tool(&self, params: &ToolsCallParams) -> Result<Vec<Content>, String> {
381 let url = format!("{}/v1/mcp", self.base_url);
382 let body = json!({
383 "jsonrpc": "2.0",
384 "id": 1,
385 "method": "tools/call",
386 "params": params
387 });
388
389 let response = self
390 .client
391 .post(&url)
392 .json(&body)
393 .send()
394 .await
395 .map_err(|e| e.to_string())?;
396
397 let resp: Value = self.handle_response(response).await?;
398
399 if let Some(result) = resp.get("result")
401 && let Some(content) = result.get("content")
402 {
403 let content: Vec<Content> =
404 serde_json::from_value(content.clone()).map_err(|e| e.to_string())?;
405 return Ok(content);
406 }
407
408 if let Some(error) = resp.get("error") {
410 let msg = error
411 .get("message")
412 .and_then(|m| m.as_str())
413 .unwrap_or("Unknown error");
414 return Err(msg.to_string());
415 }
416
417 Err("Invalid MCP response format".to_string())
418 }
419
420 async fn handle_response<T: DeserializeOwned>(&self, response: Response) -> Result<T, String> {
422 let response = self.handle_response_error(response).await?;
423 response.json().await.map_err(|e| e.to_string())
424 }
425
426 async fn handle_response_no_body(&self, response: Response) -> Result<(), String> {
428 self.handle_response_error(response).await?;
429 Ok(())
430 }
431
432 async fn handle_response_error(&self, response: Response) -> Result<Response, String> {
434 if response.status().is_success() {
435 return Ok(response);
436 }
437
438 let status = response.status();
439 let error_body = response.text().await.unwrap_or_default();
440
441 if let Ok(api_error) = serde_json::from_str::<ApiError>(&error_body) {
443 if api_error.error.key == "EXCEEDED_API_LIMIT" {
445 return Err(format!(
446 "{}. You can top up your billing at https://stakpak.dev/settings/billing",
447 api_error.error.message
448 ));
449 }
450 return Err(api_error.error.message);
451 }
452
453 Err(format!("API error {}: {}", status, error_body))
454 }
455}
456
457impl CreateSessionRequest {
462 pub fn new(title: impl Into<String>, state: CheckpointState) -> Self {
464 Self {
465 title: title.into(),
466 visibility: Some(SessionVisibility::Private),
467 cwd: None,
468 state,
469 }
470 }
471
472 pub fn with_cwd(mut self, cwd: impl Into<String>) -> Self {
474 self.cwd = Some(cwd.into());
475 self
476 }
477
478 pub fn with_visibility(mut self, visibility: SessionVisibility) -> Self {
480 self.visibility = Some(visibility);
481 self
482 }
483}
484
485impl CreateCheckpointRequest {
486 pub fn new(state: CheckpointState) -> Self {
488 Self {
489 state,
490 parent_id: None,
491 }
492 }
493
494 pub fn with_parent(mut self, parent_id: Uuid) -> Self {
496 self.parent_id = Some(parent_id);
497 self
498 }
499}