stakpak_api/remote/mod.rs

1use crate::AgentProvider;
2use crate::models::*;
3use async_trait::async_trait;
4use eventsource_stream::Eventsource;
5use futures_util::Stream;
6use futures_util::StreamExt;
7use reqwest::header::HeaderMap;
8use reqwest::{Client as ReqwestClient, Error as ReqwestError, Response, header};
9use rmcp::model::Content;
10use rmcp::model::JsonRpcResponse;
11use serde::Deserialize;
12use serde_json::json;
13use stakpak_shared::models::integrations::openai::{
14    AgentModel, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse,
15    ChatMessage, Tool,
16};
17use stakpak_shared::tls_client::TlsClientConfig;
18use stakpak_shared::tls_client::create_tls_client;
19use uuid::Uuid;
20
21#[derive(Clone, Debug)]
22pub struct RemoteClient {
23    client: ReqwestClient,
24    base_url: String,
25}
26
/// Settings consumed by `RemoteClient::new` when building a client.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Token sent as a `Bearer` authorization header; `None` makes client construction fail.
    pub api_key: Option<String>,
    /// Base URL of the API; `/v1` is appended to it when the client is built.
    pub api_endpoint: String,
}
32
33#[derive(Deserialize)]
34struct ApiError {
35    error: ApiErrorDetail,
36}
37
38#[derive(Deserialize)]
39struct ApiErrorDetail {
40    key: String,
41    message: String,
42}
43
44impl RemoteClient {
45    async fn handle_response_error(&self, response: Response) -> Result<Response, String> {
46        if response.status().is_success() {
47            Ok(response)
48        } else {
49            let error_body = response
50                .text()
51                .await
52                .unwrap_or_else(|_| "Failed to read error body".to_string());
53
54            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&error_body) {
55                if let Ok(api_error) = serde_json::from_value::<ApiError>(json.clone()) {
56                    if api_error.error.key == "EXCEEDED_API_LIMIT" {
57                        return Err(format!(
58                            "{}.\n\nPlease top up your account at https://stakpak.dev/settings/billing to keep Stakpaking.",
59                            api_error.error.message
60                        ));
61                    } else {
62                        return Err(api_error.error.message);
63                    }
64                }
65
66                if let Some(error_obj) = json.get("error") {
67                    let error_message =
68                        if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
69                            message.to_string()
70                        } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
71                            format!("API error: {}", code)
72                        } else if let Some(key) = error_obj.get("key").and_then(|k| k.as_str()) {
73                            format!("API error: {}", key)
74                        } else {
75                            serde_json::to_string(error_obj)
76                                .unwrap_or_else(|_| "Unknown API error".to_string())
77                        };
78                    return Err(error_message);
79                }
80            }
81
82            Err(error_body)
83        }
84    }
85
86    async fn call_mcp_tool(&self, input: &ToolsCallParams) -> Result<Vec<Content>, String> {
87        let url = format!("{}/mcp", self.base_url);
88
89        let payload = json!({
90            "jsonrpc": "2.0",
91            "method": "tools/call",
92            "params": {
93                "name": input.name,
94                "arguments": input.arguments,
95            },
96            "id": Uuid::new_v4().to_string(),
97        });
98
99        let response = self
100            .client
101            .post(&url)
102            .json(&payload)
103            .send()
104            .await
105            .map_err(|e: ReqwestError| e.to_string())?;
106
107        let response = self.handle_response_error(response).await?;
108
109        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
110
111        match serde_json::from_value::<JsonRpcResponse<ToolsCallResponse>>(value.clone()) {
112            Ok(response) => Ok(response.result.content),
113            Err(e) => {
114                eprintln!("Failed to deserialize response: {}", e);
115                eprintln!("Raw response: {}", value);
116                Err("Failed to deserialize response:".into())
117            }
118        }
119    }
120
121    pub fn new(config: &ClientConfig) -> Result<Self, String> {
122        if config.api_key.is_none() {
123            return Err("API Key not found, please login".into());
124        }
125
126        let mut headers = header::HeaderMap::new();
127        headers.insert(
128            header::AUTHORIZATION,
129            header::HeaderValue::from_str(&format!("Bearer {}", config.api_key.clone().unwrap()))
130                .expect("Invalid API key format"),
131        );
132        headers.insert(
133            header::USER_AGENT,
134            header::HeaderValue::from_str(&format!("Stakpak/{}", env!("CARGO_PKG_VERSION")))
135                .expect("Invalid user agent format"),
136        );
137
138        let client = create_tls_client(
139            TlsClientConfig::default()
140                .with_headers(headers)
141                .with_timeout(std::time::Duration::from_secs(300)),
142        )?;
143
144        Ok(Self {
145            client,
146            base_url: config.api_endpoint.clone() + "/v1",
147        })
148    }
149}
150
151#[async_trait]
152impl AgentProvider for RemoteClient {
153    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
154        let url = format!("{}/account", self.base_url);
155
156        let response = self
157            .client
158            .get(&url)
159            .send()
160            .await
161            .map_err(|e: ReqwestError| e.to_string())?;
162
163        let response = self.handle_response_error(response).await?;
164
165        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
166        match serde_json::from_value::<GetMyAccountResponse>(value.clone()) {
167            Ok(response) => Ok(response),
168            Err(e) => {
169                eprintln!("Failed to deserialize response: {}", e);
170                eprintln!("Raw response: {}", value);
171                Err("Failed to deserialize response:".into())
172            }
173        }
174    }
175
176    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
177        let url = format!("{}/rules", self.base_url);
178
179        let response = self
180            .client
181            .get(&url)
182            .send()
183            .await
184            .map_err(|e: ReqwestError| e.to_string())?;
185
186        let response = self.handle_response_error(response).await?;
187
188        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
189        match serde_json::from_value::<ListRulebooksResponse>(value.clone()) {
190            Ok(response) => Ok(response.results),
191            Err(e) => {
192                eprintln!("Failed to deserialize response: {}", e);
193                eprintln!("Raw response: {}", value);
194                Err("Failed to deserialize response:".into())
195            }
196        }
197    }
198
199    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
200        // URL encode the URI to handle special characters
201        let encoded_uri = urlencoding::encode(uri);
202        let url = format!("{}/rules/{}", self.base_url, encoded_uri);
203
204        let response = self
205            .client
206            .get(&url)
207            .send()
208            .await
209            .map_err(|e: ReqwestError| e.to_string())?;
210
211        let response = self.handle_response_error(response).await?;
212
213        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
214        match serde_json::from_value::<RuleBook>(value.clone()) {
215            Ok(response) => Ok(response),
216            Err(e) => {
217                eprintln!("Failed to deserialize response: {}", e);
218                eprintln!("Raw response: {}", value);
219                Err("Failed to deserialize response:".into())
220            }
221        }
222    }
223
224    async fn create_rulebook(
225        &self,
226        uri: &str,
227        description: &str,
228        content: &str,
229        tags: Vec<String>,
230        visibility: Option<RuleBookVisibility>,
231    ) -> Result<CreateRuleBookResponse, String> {
232        let url = format!("{}/rules", self.base_url);
233
234        let input = CreateRuleBookInput {
235            uri: uri.to_string(),
236            description: description.to_string(),
237            content: content.to_string(),
238            tags,
239            visibility,
240        };
241
242        let response = self
243            .client
244            .post(&url)
245            .json(&input)
246            .send()
247            .await
248            .map_err(|e: ReqwestError| e.to_string())?;
249
250        // Check status before consuming body
251        if !response.status().is_success() {
252            let status = response.status();
253            let error_text = response
254                .text()
255                .await
256                .unwrap_or_else(|_| "Unknown error".to_string());
257            return Err(format!("API error ({}): {}", status, error_text));
258        }
259
260        // Get response as text first to handle non-JSON responses
261        let response_text = response.text().await.map_err(|e| e.to_string())?;
262
263        // Try to parse as JSON first
264        if let Ok(value) = serde_json::from_str::<serde_json::Value>(&response_text) {
265            match serde_json::from_value::<CreateRuleBookResponse>(value.clone()) {
266                Ok(response) => return Ok(response),
267                Err(e) => {
268                    eprintln!("Failed to deserialize JSON response: {}", e);
269                    eprintln!("Raw response: {}", value);
270                }
271            }
272        }
273
274        // If JSON parsing failed, try to parse as plain text "id: <uuid>"
275        if response_text.starts_with("id: ") {
276            let id = response_text.trim_start_matches("id: ").trim().to_string();
277            return Ok(CreateRuleBookResponse { id });
278        }
279
280        Err(format!("Unexpected response format: {}", response_text))
281    }
282
283    async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
284        let encoded_uri = urlencoding::encode(uri);
285        let url = format!("{}/rules/{}", self.base_url, encoded_uri);
286
287        let response = self
288            .client
289            .delete(&url)
290            .send()
291            .await
292            .map_err(|e: ReqwestError| e.to_string())?;
293
294        let _response = self.handle_response_error(response).await?;
295
296        Ok(())
297    }
298
299    async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
300        let url = format!("{}/agents/sessions", self.base_url);
301
302        let response = self
303            .client
304            .get(&url)
305            .send()
306            .await
307            .map_err(|e: ReqwestError| e.to_string())?;
308
309        let response = self.handle_response_error(response).await?;
310
311        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
312        match serde_json::from_value::<Vec<AgentSession>>(value.clone()) {
313            Ok(response) => Ok(response),
314            Err(e) => {
315                eprintln!("Failed to deserialize response: {}", e);
316                eprintln!("Raw response: {}", value);
317                Err("Failed to deserialize response:".into())
318            }
319        }
320    }
321
322    async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
323        let url = format!("{}/agents/sessions/{}", self.base_url, session_id);
324
325        let response = self
326            .client
327            .get(&url)
328            .send()
329            .await
330            .map_err(|e: ReqwestError| e.to_string())?;
331
332        let response = self.handle_response_error(response).await?;
333
334        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
335
336        match serde_json::from_value::<AgentSession>(value.clone()) {
337            Ok(response) => Ok(response),
338            Err(e) => {
339                eprintln!("Failed to deserialize response: {}", e);
340                eprintln!("Raw response: {}", value);
341                Err("Failed to deserialize response:".into())
342            }
343        }
344    }
345
346    async fn get_agent_session_stats(&self, session_id: Uuid) -> Result<AgentSessionStats, String> {
347        let url = format!("{}/agents/sessions/{}/stats", self.base_url, session_id);
348
349        let response = self
350            .client
351            .get(&url)
352            .send()
353            .await
354            .map_err(|e: ReqwestError| e.to_string())?;
355
356        let response = self.handle_response_error(response).await?;
357
358        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
359
360        match serde_json::from_value::<AgentSessionStats>(value.clone()) {
361            Ok(response) => Ok(response),
362            Err(e) => {
363                eprintln!("Failed to deserialize response: {}", e);
364                eprintln!("Raw response: {}", value);
365                Err("Failed to deserialize response:".into())
366            }
367        }
368    }
369
370    async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
371        let url = format!("{}/agents/checkpoints/{}", self.base_url, checkpoint_id);
372
373        let response = self
374            .client
375            .get(&url)
376            .send()
377            .await
378            .map_err(|e: ReqwestError| e.to_string())?;
379
380        let response = self.handle_response_error(response).await?;
381
382        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
383        match serde_json::from_value::<RunAgentOutput>(value.clone()) {
384            Ok(response) => Ok(response),
385            Err(e) => {
386                eprintln!("Failed to deserialize response: {}", e);
387                eprintln!("Raw response: {}", value);
388                Err("Failed to deserialize response:".into())
389            }
390        }
391    }
392
393    async fn get_agent_session_latest_checkpoint(
394        &self,
395        session_id: Uuid,
396    ) -> Result<RunAgentOutput, String> {
397        let url = format!(
398            "{}/agents/sessions/{}/checkpoints/latest",
399            self.base_url, session_id
400        );
401
402        let response = self
403            .client
404            .get(&url)
405            .send()
406            .await
407            .map_err(|e: ReqwestError| e.to_string())?;
408
409        let response = self.handle_response_error(response).await?;
410
411        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
412        match serde_json::from_value::<RunAgentOutput>(value.clone()) {
413            Ok(response) => Ok(response),
414            Err(e) => {
415                eprintln!("Failed to deserialize response: {}", e);
416                eprintln!("Raw response: {}", value);
417                Err("Failed to deserialize response:".into())
418            }
419        }
420    }
421
422    async fn chat_completion(
423        &self,
424        model: AgentModel,
425        messages: Vec<ChatMessage>,
426        tools: Option<Vec<Tool>>,
427    ) -> Result<ChatCompletionResponse, String> {
428        let url = format!("{}/agents/openai/v1/chat/completions", self.base_url);
429
430        let model_string = model.to_string();
431        let input = ChatCompletionRequest::new(model_string.clone(), messages, tools, None);
432
433        let response = self
434            .client
435            .post(&url)
436            .json(&input)
437            .send()
438            .await
439            .map_err(|e: ReqwestError| e.to_string())?;
440
441        let response = self.handle_response_error(response).await?;
442
443        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
444
445        if let Some(error_obj) = value.get("error") {
446            let error_message = if let Some(message) =
447                error_obj.get("message").and_then(|m| m.as_str())
448            {
449                message.to_string()
450            } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
451                format!("API error: {}", code)
452            } else if let Some(key) = error_obj.get("key").and_then(|k| k.as_str()) {
453                format!("API error: {}", key)
454            } else {
455                serde_json::to_string(error_obj).unwrap_or_else(|_| "Unknown API error".to_string())
456            };
457            return Err(error_message);
458        }
459
460        match serde_json::from_value::<ChatCompletionResponse>(value.clone()) {
461            Ok(response) => Ok(response),
462            Err(e) => {
463                eprintln!("Failed to deserialize response: {}", e);
464                eprintln!("Raw response: {}", value);
465                Err("Failed to deserialize response:".into())
466            }
467        }
468    }
469
470    async fn chat_completion_stream(
471        &self,
472        model: AgentModel,
473        messages: Vec<ChatMessage>,
474        tools: Option<Vec<Tool>>,
475        headers: Option<HeaderMap>,
476    ) -> Result<
477        (
478            std::pin::Pin<
479                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
480            >,
481            Option<String>,
482        ),
483        String,
484    > {
485        let url = format!("{}/agents/openai/v1/chat/completions", self.base_url);
486
487        let model_string = model.to_string();
488        let input = ChatCompletionRequest::new(model_string.clone(), messages, tools, Some(true));
489
490        let response = self
491            .client
492            .post(&url)
493            .headers(headers.unwrap_or_default())
494            .json(&input)
495            .send()
496            .await
497            .map_err(|e: ReqwestError| e.to_string())?;
498
499        // Check content-type before processing
500        let content_type = response
501            .headers()
502            .get("content-type")
503            .and_then(|v| v.to_str().ok())
504            .unwrap_or("unknown");
505
506        // Extract x-request-id from headers
507        let request_id = response
508            .headers()
509            .get("x-request-id")
510            .and_then(|v| v.to_str().ok())
511            .map(|s| s.to_string());
512
513        // If content-type is not event-stream, it's likely an error message
514        if !content_type.contains("event-stream") && !content_type.contains("text/event-stream") {
515            let status = response.status();
516            let error_body = response
517                .text()
518                .await
519                .unwrap_or_else(|_| "Failed to read error body".to_string());
520
521            let error_message =
522                if let Ok(json) = serde_json::from_str::<serde_json::Value>(&error_body) {
523                    // Try ApiError format first (Stakpak API format)
524                    if let Ok(api_error) = serde_json::from_value::<ApiError>(json.clone()) {
525                        api_error.error.message
526                    } else if let Some(error_obj) = json.get("error") {
527                        // Generic error format
528                        if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
529                            message.to_string()
530                        } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
531                            format!("API error: {}", code)
532                        } else {
533                            error_body
534                        }
535                    } else {
536                        error_body
537                    }
538                } else {
539                    error_body
540                };
541
542            return Err(format!(
543                "Server returned non-stream response ({}): {}",
544                status, error_message
545            ));
546        }
547
548        let response = self.handle_response_error(response).await?;
549        let stream = response.bytes_stream().eventsource().map(move |event| {
550            event
551                .map_err(|_| ApiStreamError::Unknown("Failed to read response".to_string()))
552                .and_then(|event| match event.event.as_str() {
553                    "error" => Err(ApiStreamError::from(event.data)),
554                    _ => serde_json::from_str::<ChatCompletionStreamResponse>(&event.data).map_err(
555                        |_| {
556                            ApiStreamError::Unknown(
557                                "Failed to parse JSON from Anthropic response".to_string(),
558                            )
559                        },
560                    ),
561                })
562        });
563
564        Ok((Box::pin(stream), request_id))
565    }
566
567    async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
568        let url = format!("{}/agents/requests/{}/cancel", self.base_url, request_id);
569        self.client
570            .post(&url)
571            .send()
572            .await
573            .map_err(|e: ReqwestError| e.to_string())?;
574
575        Ok(())
576    }
577
578    // async fn build_code_index(
579    //     &self,
580    //     input: &BuildCodeIndexInput,
581    // ) -> Result<BuildCodeIndexOutput, String> {
582    //     let url = format!("{}/commands/build_code_index", self.base_url,);
583
584    //     let response = self
585    //         .client
586    //         .post(&url)
587    //         .json(&input)
588    //         .send()
589    //         .await
590    //         .map_err(|e: ReqwestError| e.to_string())?;
591
592    //     let response = self.handle_response_error(response).await?;
593
594    //     let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
595    //     match serde_json::from_value::<BuildCodeIndexOutput>(value.clone()) {
596    //         Ok(response) => Ok(response),
597    //         Err(e) => {
598    //             eprintln!("Failed to deserialize response: {}", e);
599    //             eprintln!("Raw response: {}", value);
600    //             Err("Failed to deserialize response:".into())
601    //         }
602    //     }
603    // }
604
605    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
606        self.call_mcp_tool(&ToolsCallParams {
607            name: "search_docs".to_string(),
608            arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
609        })
610        .await
611    }
612
613    async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
614        self.call_mcp_tool(&ToolsCallParams {
615            name: "search_memory".to_string(),
616            arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
617        })
618        .await
619    }
620
621    async fn slack_read_messages(
622        &self,
623        input: &SlackReadMessagesRequest,
624    ) -> Result<Vec<Content>, String> {
625        self.call_mcp_tool(&ToolsCallParams {
626            name: "slack_read_messages".to_string(),
627            arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
628        })
629        .await
630    }
631
632    async fn slack_read_replies(
633        &self,
634        input: &SlackReadRepliesRequest,
635    ) -> Result<Vec<Content>, String> {
636        self.call_mcp_tool(&ToolsCallParams {
637            name: "slack_read_replies".to_string(),
638            arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
639        })
640        .await
641    }
642
643    async fn slack_send_message(
644        &self,
645        input: &SlackSendMessageRequest,
646    ) -> Result<Vec<Content>, String> {
647        // Note: The remote tool expects "markdown_text" but the struct has "mrkdwn_text".
648        // We need to map this correctly. The struct in models.rs has mrkdwn_text.
649        // The remote tool likely expects what was previously passed.
650        // In slack.rs, it was mapping "mrkdwn_text" to "markdown_text".
651        // So we should construct the arguments manually or use a custom serializer if we want to match exactly.
652        // However, since we are sending `input` which is `SlackSendMessageRequest`, let's check its definition.
653        // It has `mrkdwn_text`.
654        // The previous implementation in slack.rs did:
655        // arguments: json!({
656        //     "channel": channel,
657        //     "markdown_text": mrkdwn_text,
658        //     "thread_ts": thread_ts,
659        // }),
660        // So we need to replicate this mapping.
661
662        let arguments = json!({
663            "channel": input.channel,
664            "markdown_text": input.mrkdwn_text,
665            "thread_ts": input.thread_ts,
666        });
667
668        self.call_mcp_tool(&ToolsCallParams {
669            name: "slack_send_message".to_string(),
670            arguments,
671        })
672        .await
673    }
674
675    async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
676        let url = format!(
677            "{}/agents/sessions/checkpoints/{}/extract-memory",
678            self.base_url, checkpoint_id
679        );
680
681        let response = self
682            .client
683            .post(&url)
684            .send()
685            .await
686            .map_err(|e: ReqwestError| e.to_string())?;
687
688        let _ = self.handle_response_error(response).await?;
689        Ok(())
690    }
691}