stakpak_api/remote/
mod.rs

1use crate::AgentProvider;
2use crate::models::*;
3use async_trait::async_trait;
4use eventsource_stream::Eventsource;
5use futures_util::Stream;
6use futures_util::StreamExt;
7use reqwest::header::HeaderMap;
8use reqwest::{Client as ReqwestClient, Error as ReqwestError, Response, header};
9use rmcp::model::Content;
10use rmcp::model::JsonRpcResponse;
11use serde::Deserialize;
12use serde_json::json;
13use stakpak_shared::models::integrations::openai::{
14    AgentModel, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse,
15    ChatMessage, Tool,
16};
17use stakpak_shared::tls_client::TlsClientConfig;
18use stakpak_shared::tls_client::create_tls_client;
19use uuid::Uuid;
20
21#[derive(Clone, Debug)]
22pub struct RemoteClient {
23    client: ReqwestClient,
24    base_url: String,
25}
26
/// Settings needed to construct a `RemoteClient`.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// API key sent as a bearer token; construction fails when this is `None`.
    pub api_key: Option<String>,
    /// Root URL of the API; the versioned `/v1` prefix is appended to it.
    pub api_endpoint: String,
}
32
33#[derive(Deserialize)]
34struct ApiError {
35    error: ApiErrorDetail,
36}
37
38#[derive(Deserialize)]
39struct ApiErrorDetail {
40    key: String,
41    message: String,
42}
43
44impl RemoteClient {
45    async fn handle_response_error(&self, response: Response) -> Result<Response, String> {
46        if response.status().is_success() {
47            Ok(response)
48        } else {
49            match response.json::<ApiError>().await {
50                Ok(response) => {
51                    if response.error.key == "EXCEEDED_API_LIMIT" {
52                        Err(format!(
53                            "{}.\n\nPlease top up your account at https://stakpak.dev/settings/billing to keep Stakpaking.",
54                            response.error.message
55                        ))
56                    } else {
57                        Err(response.error.message)
58                    }
59                }
60                Err(e) => Err(e.to_string()),
61            }
62        }
63    }
64
65    async fn call_mcp_tool(&self, input: &ToolsCallParams) -> Result<Vec<Content>, String> {
66        let url = format!("{}/mcp", self.base_url);
67
68        let payload = json!({
69            "jsonrpc": "2.0",
70            "method": "tools/call",
71            "params": {
72                "name": input.name,
73                "arguments": input.arguments,
74            },
75            "id": Uuid::new_v4().to_string(),
76        });
77
78        let response = self
79            .client
80            .post(&url)
81            .json(&payload)
82            .send()
83            .await
84            .map_err(|e: ReqwestError| e.to_string())?;
85
86        let response = self.handle_response_error(response).await?;
87
88        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
89
90        match serde_json::from_value::<JsonRpcResponse<ToolsCallResponse>>(value.clone()) {
91            Ok(response) => Ok(response.result.content),
92            Err(e) => {
93                eprintln!("Failed to deserialize response: {}", e);
94                eprintln!("Raw response: {}", value);
95                Err("Failed to deserialize response:".into())
96            }
97        }
98    }
99
100    pub fn new(config: &ClientConfig) -> Result<Self, String> {
101        if config.api_key.is_none() {
102            return Err("API Key not found, please login".into());
103        }
104
105        let mut headers = header::HeaderMap::new();
106        headers.insert(
107            header::AUTHORIZATION,
108            header::HeaderValue::from_str(&format!("Bearer {}", config.api_key.clone().unwrap()))
109                .expect("Invalid API key format"),
110        );
111        headers.insert(
112            header::USER_AGENT,
113            header::HeaderValue::from_str(&format!("Stakpak/{}", env!("CARGO_PKG_VERSION")))
114                .expect("Invalid user agent format"),
115        );
116
117        let client = create_tls_client(
118            TlsClientConfig::default()
119                .with_headers(headers)
120                .with_timeout(std::time::Duration::from_secs(300)),
121        )?;
122
123        Ok(Self {
124            client,
125            base_url: config.api_endpoint.clone() + "/v1",
126        })
127    }
128}
129
130#[async_trait]
131impl AgentProvider for RemoteClient {
132    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
133        let url = format!("{}/account", self.base_url);
134
135        let response = self
136            .client
137            .get(&url)
138            .send()
139            .await
140            .map_err(|e: ReqwestError| e.to_string())?;
141
142        let response = self.handle_response_error(response).await?;
143
144        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
145        match serde_json::from_value::<GetMyAccountResponse>(value.clone()) {
146            Ok(response) => Ok(response),
147            Err(e) => {
148                eprintln!("Failed to deserialize response: {}", e);
149                eprintln!("Raw response: {}", value);
150                Err("Failed to deserialize response:".into())
151            }
152        }
153    }
154
155    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
156        let url = format!("{}/rules", self.base_url);
157
158        let response = self
159            .client
160            .get(&url)
161            .send()
162            .await
163            .map_err(|e: ReqwestError| e.to_string())?;
164
165        let response = self.handle_response_error(response).await?;
166
167        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
168        match serde_json::from_value::<ListRulebooksResponse>(value.clone()) {
169            Ok(response) => Ok(response.results),
170            Err(e) => {
171                eprintln!("Failed to deserialize response: {}", e);
172                eprintln!("Raw response: {}", value);
173                Err("Failed to deserialize response:".into())
174            }
175        }
176    }
177
178    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
179        // URL encode the URI to handle special characters
180        let encoded_uri = urlencoding::encode(uri);
181        let url = format!("{}/rules/{}", self.base_url, encoded_uri);
182
183        let response = self
184            .client
185            .get(&url)
186            .send()
187            .await
188            .map_err(|e: ReqwestError| e.to_string())?;
189
190        let response = self.handle_response_error(response).await?;
191
192        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
193        match serde_json::from_value::<RuleBook>(value.clone()) {
194            Ok(response) => Ok(response),
195            Err(e) => {
196                eprintln!("Failed to deserialize response: {}", e);
197                eprintln!("Raw response: {}", value);
198                Err("Failed to deserialize response:".into())
199            }
200        }
201    }
202
203    async fn create_rulebook(
204        &self,
205        uri: &str,
206        description: &str,
207        content: &str,
208        tags: Vec<String>,
209        visibility: Option<RuleBookVisibility>,
210    ) -> Result<CreateRuleBookResponse, String> {
211        let url = format!("{}/rules", self.base_url);
212
213        let input = CreateRuleBookInput {
214            uri: uri.to_string(),
215            description: description.to_string(),
216            content: content.to_string(),
217            tags,
218            visibility,
219        };
220
221        let response = self
222            .client
223            .post(&url)
224            .json(&input)
225            .send()
226            .await
227            .map_err(|e: ReqwestError| e.to_string())?;
228
229        // Check status before consuming body
230        if !response.status().is_success() {
231            let status = response.status();
232            let error_text = response
233                .text()
234                .await
235                .unwrap_or_else(|_| "Unknown error".to_string());
236            return Err(format!("API error ({}): {}", status, error_text));
237        }
238
239        // Get response as text first to handle non-JSON responses
240        let response_text = response.text().await.map_err(|e| e.to_string())?;
241
242        // Try to parse as JSON first
243        if let Ok(value) = serde_json::from_str::<serde_json::Value>(&response_text) {
244            match serde_json::from_value::<CreateRuleBookResponse>(value.clone()) {
245                Ok(response) => return Ok(response),
246                Err(e) => {
247                    eprintln!("Failed to deserialize JSON response: {}", e);
248                    eprintln!("Raw response: {}", value);
249                }
250            }
251        }
252
253        // If JSON parsing failed, try to parse as plain text "id: <uuid>"
254        if response_text.starts_with("id: ") {
255            let id = response_text.trim_start_matches("id: ").trim().to_string();
256            return Ok(CreateRuleBookResponse { id });
257        }
258
259        Err(format!("Unexpected response format: {}", response_text))
260    }
261
262    async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
263        let encoded_uri = urlencoding::encode(uri);
264        let url = format!("{}/rules/{}", self.base_url, encoded_uri);
265
266        let response = self
267            .client
268            .delete(&url)
269            .send()
270            .await
271            .map_err(|e: ReqwestError| e.to_string())?;
272
273        let _response = self.handle_response_error(response).await?;
274
275        Ok(())
276    }
277
278    async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
279        let url = format!("{}/agents/sessions", self.base_url);
280
281        let response = self
282            .client
283            .get(&url)
284            .send()
285            .await
286            .map_err(|e: ReqwestError| e.to_string())?;
287
288        let response = self.handle_response_error(response).await?;
289
290        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
291        match serde_json::from_value::<Vec<AgentSession>>(value.clone()) {
292            Ok(response) => Ok(response),
293            Err(e) => {
294                eprintln!("Failed to deserialize response: {}", e);
295                eprintln!("Raw response: {}", value);
296                Err("Failed to deserialize response:".into())
297            }
298        }
299    }
300
301    async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
302        let url = format!("{}/agents/sessions/{}", self.base_url, session_id);
303
304        let response = self
305            .client
306            .get(&url)
307            .send()
308            .await
309            .map_err(|e: ReqwestError| e.to_string())?;
310
311        let response = self.handle_response_error(response).await?;
312
313        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
314
315        match serde_json::from_value::<AgentSession>(value.clone()) {
316            Ok(response) => Ok(response),
317            Err(e) => {
318                eprintln!("Failed to deserialize response: {}", e);
319                eprintln!("Raw response: {}", value);
320                Err("Failed to deserialize response:".into())
321            }
322        }
323    }
324
325    async fn get_agent_session_stats(&self, session_id: Uuid) -> Result<AgentSessionStats, String> {
326        let url = format!("{}/agents/sessions/{}/stats", self.base_url, session_id);
327
328        let response = self
329            .client
330            .get(&url)
331            .send()
332            .await
333            .map_err(|e: ReqwestError| e.to_string())?;
334
335        let response = self.handle_response_error(response).await?;
336
337        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
338
339        match serde_json::from_value::<AgentSessionStats>(value.clone()) {
340            Ok(response) => Ok(response),
341            Err(e) => {
342                eprintln!("Failed to deserialize response: {}", e);
343                eprintln!("Raw response: {}", value);
344                Err("Failed to deserialize response:".into())
345            }
346        }
347    }
348
349    async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
350        let url = format!("{}/agents/checkpoints/{}", self.base_url, checkpoint_id);
351
352        let response = self
353            .client
354            .get(&url)
355            .send()
356            .await
357            .map_err(|e: ReqwestError| e.to_string())?;
358
359        let response = self.handle_response_error(response).await?;
360
361        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
362        match serde_json::from_value::<RunAgentOutput>(value.clone()) {
363            Ok(response) => Ok(response),
364            Err(e) => {
365                eprintln!("Failed to deserialize response: {}", e);
366                eprintln!("Raw response: {}", value);
367                Err("Failed to deserialize response:".into())
368            }
369        }
370    }
371
372    async fn get_agent_session_latest_checkpoint(
373        &self,
374        session_id: Uuid,
375    ) -> Result<RunAgentOutput, String> {
376        let url = format!(
377            "{}/agents/sessions/{}/checkpoints/latest",
378            self.base_url, session_id
379        );
380
381        let response = self
382            .client
383            .get(&url)
384            .send()
385            .await
386            .map_err(|e: ReqwestError| e.to_string())?;
387
388        let response = self.handle_response_error(response).await?;
389
390        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
391        match serde_json::from_value::<RunAgentOutput>(value.clone()) {
392            Ok(response) => Ok(response),
393            Err(e) => {
394                eprintln!("Failed to deserialize response: {}", e);
395                eprintln!("Raw response: {}", value);
396                Err("Failed to deserialize response:".into())
397            }
398        }
399    }
400
401    async fn chat_completion(
402        &self,
403        model: AgentModel,
404        messages: Vec<ChatMessage>,
405        tools: Option<Vec<Tool>>,
406    ) -> Result<ChatCompletionResponse, String> {
407        let url = format!("{}/agents/openai/v1/chat/completions", self.base_url);
408
409        let input = ChatCompletionRequest::new(model.to_string(), messages, tools, None);
410
411        let response = self
412            .client
413            .post(&url)
414            .json(&input)
415            .send()
416            .await
417            .map_err(|e: ReqwestError| e.to_string())?;
418
419        let response = self.handle_response_error(response).await?;
420
421        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
422
423        match serde_json::from_value::<ChatCompletionResponse>(value.clone()) {
424            Ok(response) => Ok(response),
425            Err(e) => {
426                eprintln!("Failed to deserialize response: {}", e);
427                eprintln!("Raw response: {}", value);
428                Err("Failed to deserialize response:".into())
429            }
430        }
431    }
432
433    async fn chat_completion_stream(
434        &self,
435        model: AgentModel,
436        messages: Vec<ChatMessage>,
437        tools: Option<Vec<Tool>>,
438        headers: Option<HeaderMap>,
439    ) -> Result<
440        (
441            std::pin::Pin<
442                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
443            >,
444            Option<String>,
445        ),
446        String,
447    > {
448        let url = format!("{}/agents/openai/v1/chat/completions", self.base_url);
449
450        let input = ChatCompletionRequest::new(model.to_string(), messages, tools, Some(true));
451
452        let response = self
453            .client
454            .post(&url)
455            .headers(headers.unwrap_or_default())
456            .json(&input)
457            .send()
458            .await
459            .map_err(|e: ReqwestError| e.to_string())?;
460
461        // Check content-type before processing
462        let content_type = response
463            .headers()
464            .get("content-type")
465            .and_then(|v| v.to_str().ok())
466            .unwrap_or("unknown");
467
468        // Extract x-request-id from headers
469        let request_id = response
470            .headers()
471            .get("x-request-id")
472            .and_then(|v| v.to_str().ok())
473            .map(|s| s.to_string());
474
475        // If content-type is not event-stream, it's likely an error message
476        if !content_type.contains("event-stream") && !content_type.contains("text/event-stream") {
477            let status = response.status();
478            let error_body = response
479                .text()
480                .await
481                .unwrap_or_else(|_| "Failed to read error body".to_string());
482            return Err(format!(
483                "Server returned non-stream response ({}): {}",
484                status, error_body
485            ));
486        }
487
488        let response = self.handle_response_error(response).await?;
489        let stream = response.bytes_stream().eventsource().map(|event| {
490            event
491                .map_err(|err| {
492                    eprintln!("stream: failed to read response: {:?}", err);
493                    ApiStreamError::Unknown("Failed to read response".to_string())
494                })
495                .and_then(|event| match event.event.as_str() {
496                    "error" => Err(ApiStreamError::from(event.data)),
497                    _ => serde_json::from_str::<ChatCompletionStreamResponse>(&event.data).map_err(
498                        |_| {
499                            ApiStreamError::Unknown(
500                                "Failed to parse JSON from Anthropic response".to_string(),
501                            )
502                        },
503                    ),
504                })
505        });
506
507        Ok((Box::pin(stream), request_id))
508    }
509
510    async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
511        let url = format!("{}/agents/requests/{}/cancel", self.base_url, request_id);
512        self.client
513            .post(&url)
514            .send()
515            .await
516            .map_err(|e: ReqwestError| e.to_string())?;
517
518        Ok(())
519    }
520
521    // async fn build_code_index(
522    //     &self,
523    //     input: &BuildCodeIndexInput,
524    // ) -> Result<BuildCodeIndexOutput, String> {
525    //     let url = format!("{}/commands/build_code_index", self.base_url,);
526
527    //     let response = self
528    //         .client
529    //         .post(&url)
530    //         .json(&input)
531    //         .send()
532    //         .await
533    //         .map_err(|e: ReqwestError| e.to_string())?;
534
535    //     let response = self.handle_response_error(response).await?;
536
537    //     let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
538    //     match serde_json::from_value::<BuildCodeIndexOutput>(value.clone()) {
539    //         Ok(response) => Ok(response),
540    //         Err(e) => {
541    //             eprintln!("Failed to deserialize response: {}", e);
542    //             eprintln!("Raw response: {}", value);
543    //             Err("Failed to deserialize response:".into())
544    //         }
545    //     }
546    // }
547
548    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
549        self.call_mcp_tool(&ToolsCallParams {
550            name: "search_docs".to_string(),
551            arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
552        })
553        .await
554    }
555
556    async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
557        self.call_mcp_tool(&ToolsCallParams {
558            name: "search_memory".to_string(),
559            arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
560        })
561        .await
562    }
563
564    async fn slack_read_messages(
565        &self,
566        input: &SlackReadMessagesRequest,
567    ) -> Result<Vec<Content>, String> {
568        self.call_mcp_tool(&ToolsCallParams {
569            name: "slack_read_messages".to_string(),
570            arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
571        })
572        .await
573    }
574
575    async fn slack_read_replies(
576        &self,
577        input: &SlackReadRepliesRequest,
578    ) -> Result<Vec<Content>, String> {
579        self.call_mcp_tool(&ToolsCallParams {
580            name: "slack_read_replies".to_string(),
581            arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
582        })
583        .await
584    }
585
586    async fn slack_send_message(
587        &self,
588        input: &SlackSendMessageRequest,
589    ) -> Result<Vec<Content>, String> {
590        // Note: The remote tool expects "markdown_text" but the struct has "mrkdwn_text".
591        // We need to map this correctly. The struct in models.rs has mrkdwn_text.
592        // The remote tool likely expects what was previously passed.
593        // In slack.rs, it was mapping "mrkdwn_text" to "markdown_text".
594        // So we should construct the arguments manually or use a custom serializer if we want to match exactly.
595        // However, since we are sending `input` which is `SlackSendMessageRequest`, let's check its definition.
596        // It has `mrkdwn_text`.
597        // The previous implementation in slack.rs did:
598        // arguments: json!({
599        //     "channel": channel,
600        //     "markdown_text": mrkdwn_text,
601        //     "thread_ts": thread_ts,
602        // }),
603        // So we need to replicate this mapping.
604
605        let arguments = json!({
606            "channel": input.channel,
607            "markdown_text": input.mrkdwn_text,
608            "thread_ts": input.thread_ts,
609        });
610
611        self.call_mcp_tool(&ToolsCallParams {
612            name: "slack_send_message".to_string(),
613            arguments,
614        })
615        .await
616    }
617
618    async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
619        let url = format!(
620            "{}/agents/sessions/checkpoints/{}/extract-memory",
621            self.base_url, checkpoint_id
622        );
623
624        let response = self
625            .client
626            .post(&url)
627            .send()
628            .await
629            .map_err(|e: ReqwestError| e.to_string())?;
630
631        let _ = self.handle_response_error(response).await?;
632        Ok(())
633    }
634}