1use crate::AgentProvider;
2use crate::models::*;
3use async_trait::async_trait;
4use eventsource_stream::Eventsource;
5use futures_util::Stream;
6use futures_util::StreamExt;
7use reqwest::header::HeaderMap;
8use reqwest::{Client as ReqwestClient, Error as ReqwestError, Response, header};
9use rmcp::model::Content;
10use rmcp::model::JsonRpcResponse;
11use serde::Deserialize;
12use serde_json::json;
13use stakpak_shared::models::integrations::openai::{
14 AgentModel, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse,
15 ChatMessage, Tool,
16};
17use stakpak_shared::tls_client::TlsClientConfig;
18use stakpak_shared::tls_client::create_tls_client;
19use uuid::Uuid;
20
21#[derive(Clone, Debug)]
22pub struct RemoteClient {
23 client: ReqwestClient,
24 base_url: String,
25}
26
/// Connection settings used to construct a [`RemoteClient`].
///
/// `Debug` is implemented by hand so that the API key is never written to logs or
/// panic messages; only whether a key is present is shown.
#[derive(Clone)]
pub struct ClientConfig {
    /// Bearer token sent in the `Authorization` header; `None` when not logged in.
    pub api_key: Option<String>,
    /// Root URL of the API, without the version segment.
    pub api_endpoint: String,
}

impl std::fmt::Debug for ClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientConfig")
            // Show presence of the key but never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("api_endpoint", &self.api_endpoint)
            .finish()
    }
}
32
33#[derive(Deserialize)]
34struct ApiError {
35 error: ApiErrorDetail,
36}
37
38#[derive(Deserialize)]
39struct ApiErrorDetail {
40 key: String,
41 message: String,
42}
43
44impl RemoteClient {
45 async fn handle_response_error(&self, response: Response) -> Result<Response, String> {
46 if response.status().is_success() {
47 Ok(response)
48 } else {
49 let error_body = response
50 .text()
51 .await
52 .unwrap_or_else(|_| "Failed to read error body".to_string());
53
54 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&error_body) {
55 if let Ok(api_error) = serde_json::from_value::<ApiError>(json.clone()) {
56 if api_error.error.key == "EXCEEDED_API_LIMIT" {
57 return Err(format!(
58 "{}.\n\nPlease top up your account at https://stakpak.dev/settings/billing to keep Stakpaking.",
59 api_error.error.message
60 ));
61 } else {
62 return Err(api_error.error.message);
63 }
64 }
65
66 if let Some(error_obj) = json.get("error") {
67 let error_message =
68 if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
69 message.to_string()
70 } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
71 format!("API error: {}", code)
72 } else if let Some(key) = error_obj.get("key").and_then(|k| k.as_str()) {
73 format!("API error: {}", key)
74 } else {
75 serde_json::to_string(error_obj)
76 .unwrap_or_else(|_| "Unknown API error".to_string())
77 };
78 return Err(error_message);
79 }
80 }
81
82 Err(error_body)
83 }
84 }
85
86 async fn call_mcp_tool(&self, input: &ToolsCallParams) -> Result<Vec<Content>, String> {
87 let url = format!("{}/mcp", self.base_url);
88
89 let payload = json!({
90 "jsonrpc": "2.0",
91 "method": "tools/call",
92 "params": {
93 "name": input.name,
94 "arguments": input.arguments,
95 },
96 "id": Uuid::new_v4().to_string(),
97 });
98
99 let response = self
100 .client
101 .post(&url)
102 .json(&payload)
103 .send()
104 .await
105 .map_err(|e: ReqwestError| e.to_string())?;
106
107 let response = self.handle_response_error(response).await?;
108
109 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
110
111 match serde_json::from_value::<JsonRpcResponse<ToolsCallResponse>>(value.clone()) {
112 Ok(response) => Ok(response.result.content),
113 Err(_) => {
114 Err("Failed to deserialize response:".into())
117 }
118 }
119 }
120
121 pub fn new(config: &ClientConfig) -> Result<Self, String> {
122 if config.api_key.is_none() {
123 return Err("API Key not found, please login".into());
124 }
125
126 let mut headers = header::HeaderMap::new();
127 headers.insert(
128 header::AUTHORIZATION,
129 header::HeaderValue::from_str(&format!("Bearer {}", config.api_key.clone().unwrap()))
130 .expect("Invalid API key format"),
131 );
132 headers.insert(
133 header::USER_AGENT,
134 header::HeaderValue::from_str(&format!("Stakpak/{}", env!("CARGO_PKG_VERSION")))
135 .expect("Invalid user agent format"),
136 );
137
138 let client = create_tls_client(
139 TlsClientConfig::default()
140 .with_headers(headers)
141 .with_timeout(std::time::Duration::from_secs(300)),
142 )?;
143
144 Ok(Self {
145 client,
146 base_url: config.api_endpoint.clone() + "/v1",
147 })
148 }
149}
150
151#[async_trait]
152impl AgentProvider for RemoteClient {
153 async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
154 let url = format!("{}/account", self.base_url);
155
156 let response = self
157 .client
158 .get(&url)
159 .send()
160 .await
161 .map_err(|e: ReqwestError| e.to_string())?;
162
163 let response = self.handle_response_error(response).await?;
164
165 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
166 match serde_json::from_value::<GetMyAccountResponse>(value.clone()) {
167 Ok(response) => Ok(response),
168 Err(e) => {
169 eprintln!("Failed to deserialize response: {}", e);
170 eprintln!("Raw response: {}", value);
171 Err("Failed to deserialize response:".into())
172 }
173 }
174 }
175
176 async fn get_billing_info(
177 &self,
178 account_username: &str,
179 ) -> Result<stakpak_shared::models::billing::BillingResponse, String> {
180 let base = self.base_url.trim_end_matches("/v1");
183 let url = format!("{}/v2/{}/billing", base, account_username);
184
185 let response = self
186 .client
187 .get(&url)
188 .send()
189 .await
190 .map_err(|e: ReqwestError| e.to_string())?;
191
192 let response = self.handle_response_error(response).await?;
193
194 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
195 match serde_json::from_value::<stakpak_shared::models::billing::BillingResponse>(
196 value.clone(),
197 ) {
198 Ok(response) => Ok(response),
199 Err(e) => {
200 let error_msg = format!("Failed to deserialize billing response: {}", e);
201 Err(error_msg)
202 }
203 }
204 }
205
206 async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
207 let url = format!("{}/rules", self.base_url);
208
209 let response = self
210 .client
211 .get(&url)
212 .send()
213 .await
214 .map_err(|e: ReqwestError| e.to_string())?;
215
216 let response = self.handle_response_error(response).await?;
217
218 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
219 match serde_json::from_value::<ListRulebooksResponse>(value.clone()) {
220 Ok(response) => Ok(response.results),
221 Err(e) => {
222 eprintln!("Failed to deserialize response: {}", e);
223 eprintln!("Raw response: {}", value);
224 Err("Failed to deserialize response:".into())
225 }
226 }
227 }
228
229 async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
230 let encoded_uri = urlencoding::encode(uri);
232 let url = format!("{}/rules/{}", self.base_url, encoded_uri);
233
234 let response = self
235 .client
236 .get(&url)
237 .send()
238 .await
239 .map_err(|e: ReqwestError| e.to_string())?;
240
241 let response = self.handle_response_error(response).await?;
242
243 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
244 match serde_json::from_value::<RuleBook>(value.clone()) {
245 Ok(response) => Ok(response),
246 Err(e) => {
247 eprintln!("Failed to deserialize response: {}", e);
248 eprintln!("Raw response: {}", value);
249 Err("Failed to deserialize response:".into())
250 }
251 }
252 }
253
254 async fn create_rulebook(
255 &self,
256 uri: &str,
257 description: &str,
258 content: &str,
259 tags: Vec<String>,
260 visibility: Option<RuleBookVisibility>,
261 ) -> Result<CreateRuleBookResponse, String> {
262 let url = format!("{}/rules", self.base_url);
263
264 let input = CreateRuleBookInput {
265 uri: uri.to_string(),
266 description: description.to_string(),
267 content: content.to_string(),
268 tags,
269 visibility,
270 };
271
272 let response = self
273 .client
274 .post(&url)
275 .json(&input)
276 .send()
277 .await
278 .map_err(|e: ReqwestError| e.to_string())?;
279
280 if !response.status().is_success() {
282 let status = response.status();
283 let error_text = response
284 .text()
285 .await
286 .unwrap_or_else(|_| "Unknown error".to_string());
287 return Err(format!("API error ({}): {}", status, error_text));
288 }
289
290 let response_text = response.text().await.map_err(|e| e.to_string())?;
292
293 if let Ok(value) = serde_json::from_str::<serde_json::Value>(&response_text) {
295 match serde_json::from_value::<CreateRuleBookResponse>(value.clone()) {
296 Ok(response) => return Ok(response),
297 Err(e) => {
298 eprintln!("Failed to deserialize JSON response: {}", e);
299 eprintln!("Raw response: {}", value);
300 }
301 }
302 }
303
304 if response_text.starts_with("id: ") {
306 let id = response_text.trim_start_matches("id: ").trim().to_string();
307 return Ok(CreateRuleBookResponse { id });
308 }
309
310 Err(format!("Unexpected response format: {}", response_text))
311 }
312
313 async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
314 let encoded_uri = urlencoding::encode(uri);
315 let url = format!("{}/rules/{}", self.base_url, encoded_uri);
316
317 let response = self
318 .client
319 .delete(&url)
320 .send()
321 .await
322 .map_err(|e: ReqwestError| e.to_string())?;
323
324 let _response = self.handle_response_error(response).await?;
325
326 Ok(())
327 }
328
329 async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
330 let url = format!("{}/agents/sessions", self.base_url);
331
332 let response = self
333 .client
334 .get(&url)
335 .send()
336 .await
337 .map_err(|e: ReqwestError| e.to_string())?;
338
339 let response = self.handle_response_error(response).await?;
340
341 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
342 match serde_json::from_value::<Vec<AgentSession>>(value.clone()) {
343 Ok(response) => Ok(response),
344 Err(e) => {
345 eprintln!("Failed to deserialize response: {}", e);
346 eprintln!("Raw response: {}", value);
347 Err("Failed to deserialize response:".into())
348 }
349 }
350 }
351
352 async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
353 let url = format!("{}/agents/sessions/{}", self.base_url, session_id);
354
355 let response = self
356 .client
357 .get(&url)
358 .send()
359 .await
360 .map_err(|e: ReqwestError| e.to_string())?;
361
362 let response = self.handle_response_error(response).await?;
363
364 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
365
366 match serde_json::from_value::<AgentSession>(value.clone()) {
367 Ok(response) => Ok(response),
368 Err(e) => {
369 eprintln!("Failed to deserialize response: {}", e);
370 eprintln!("Raw response: {}", value);
371 Err("Failed to deserialize response:".into())
372 }
373 }
374 }
375
376 async fn get_agent_session_stats(&self, session_id: Uuid) -> Result<AgentSessionStats, String> {
377 let url = format!("{}/agents/sessions/{}/stats", self.base_url, session_id);
378
379 let response = self
380 .client
381 .get(&url)
382 .send()
383 .await
384 .map_err(|e: ReqwestError| e.to_string())?;
385
386 let response = self.handle_response_error(response).await?;
387
388 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
389
390 match serde_json::from_value::<AgentSessionStats>(value.clone()) {
391 Ok(response) => Ok(response),
392 Err(e) => {
393 eprintln!("Failed to deserialize response: {}", e);
394 eprintln!("Raw response: {}", value);
395 Err("Failed to deserialize response:".into())
396 }
397 }
398 }
399
400 async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
401 let url = format!("{}/agents/checkpoints/{}", self.base_url, checkpoint_id);
402
403 let response = self
404 .client
405 .get(&url)
406 .send()
407 .await
408 .map_err(|e: ReqwestError| e.to_string())?;
409
410 let response = self.handle_response_error(response).await?;
411
412 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
413 match serde_json::from_value::<RunAgentOutput>(value.clone()) {
414 Ok(response) => Ok(response),
415 Err(e) => {
416 eprintln!("Failed to deserialize response: {}", e);
417 eprintln!("Raw response: {}", value);
418 Err("Failed to deserialize response:".into())
419 }
420 }
421 }
422
423 async fn get_agent_session_latest_checkpoint(
424 &self,
425 session_id: Uuid,
426 ) -> Result<RunAgentOutput, String> {
427 let url = format!(
428 "{}/agents/sessions/{}/checkpoints/latest",
429 self.base_url, session_id
430 );
431
432 let response = self
433 .client
434 .get(&url)
435 .send()
436 .await
437 .map_err(|e: ReqwestError| e.to_string())?;
438
439 let response = self.handle_response_error(response).await?;
440
441 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
442 match serde_json::from_value::<RunAgentOutput>(value.clone()) {
443 Ok(response) => Ok(response),
444 Err(e) => {
445 eprintln!("Failed to deserialize response: {}", e);
446 eprintln!("Raw response: {}", value);
447 Err("Failed to deserialize response:".into())
448 }
449 }
450 }
451
452 async fn chat_completion(
453 &self,
454 model: AgentModel,
455 messages: Vec<ChatMessage>,
456 tools: Option<Vec<Tool>>,
457 ) -> Result<ChatCompletionResponse, String> {
458 let url = format!("{}/agents/openai/v1/chat/completions", self.base_url);
459
460 let model_string = model.to_string();
461 let input = ChatCompletionRequest::new(model_string.clone(), messages, tools, None);
462
463 let response = self
464 .client
465 .post(&url)
466 .json(&input)
467 .send()
468 .await
469 .map_err(|e: ReqwestError| e.to_string())?;
470
471 let response = self.handle_response_error(response).await?;
472
473 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
474
475 if let Some(error_obj) = value.get("error") {
476 let error_message = if let Some(message) =
477 error_obj.get("message").and_then(|m| m.as_str())
478 {
479 message.to_string()
480 } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
481 format!("API error: {}", code)
482 } else if let Some(key) = error_obj.get("key").and_then(|k| k.as_str()) {
483 format!("API error: {}", key)
484 } else {
485 serde_json::to_string(error_obj).unwrap_or_else(|_| "Unknown API error".to_string())
486 };
487 return Err(error_message);
488 }
489
490 match serde_json::from_value::<ChatCompletionResponse>(value.clone()) {
491 Ok(response) => Ok(response),
492 Err(e) => {
493 eprintln!("Failed to deserialize response: {}", e);
494 eprintln!("Raw response: {}", value);
495 Err("Failed to deserialize response:".into())
496 }
497 }
498 }
499
500 async fn chat_completion_stream(
501 &self,
502 model: AgentModel,
503 messages: Vec<ChatMessage>,
504 tools: Option<Vec<Tool>>,
505 headers: Option<HeaderMap>,
506 ) -> Result<
507 (
508 std::pin::Pin<
509 Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
510 >,
511 Option<String>,
512 ),
513 String,
514 > {
515 let url = format!("{}/agents/openai/v1/chat/completions", self.base_url);
516
517 let model_string = model.to_string();
518 let input = ChatCompletionRequest::new(model_string.clone(), messages, tools, Some(true));
519
520 let response = self
521 .client
522 .post(&url)
523 .headers(headers.unwrap_or_default())
524 .json(&input)
525 .send()
526 .await
527 .map_err(|e: ReqwestError| e.to_string())?;
528
529 let content_type = response
531 .headers()
532 .get("content-type")
533 .and_then(|v| v.to_str().ok())
534 .unwrap_or("unknown");
535
536 let request_id = response
538 .headers()
539 .get("x-request-id")
540 .and_then(|v| v.to_str().ok())
541 .map(|s| s.to_string());
542
543 if !content_type.contains("event-stream") && !content_type.contains("text/event-stream") {
545 let status = response.status();
546 let error_body = response
547 .text()
548 .await
549 .unwrap_or_else(|_| "Failed to read error body".to_string());
550
551 let error_message =
552 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&error_body) {
553 if let Ok(api_error) = serde_json::from_value::<ApiError>(json.clone()) {
555 api_error.error.message
556 } else if let Some(error_obj) = json.get("error") {
557 if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
559 message.to_string()
560 } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
561 format!("API error: {}", code)
562 } else {
563 error_body
564 }
565 } else {
566 error_body
567 }
568 } else {
569 error_body
570 };
571
572 return Err(format!(
573 "Server returned non-stream response ({}): {}",
574 status, error_message
575 ));
576 }
577
578 let response = self.handle_response_error(response).await?;
579 let stream = response.bytes_stream().eventsource().map(move |event| {
580 event
581 .map_err(|_| ApiStreamError::Unknown("Failed to read response".to_string()))
582 .and_then(|event| match event.event.as_str() {
583 "error" => Err(ApiStreamError::from(event.data)),
584 _ => serde_json::from_str::<ChatCompletionStreamResponse>(&event.data).map_err(
585 |_| {
586 ApiStreamError::Unknown(
587 "Failed to parse JSON from Anthropic response".to_string(),
588 )
589 },
590 ),
591 })
592 });
593
594 Ok((Box::pin(stream), request_id))
595 }
596
597 async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
598 let url = format!("{}/agents/requests/{}/cancel", self.base_url, request_id);
599 self.client
600 .post(&url)
601 .send()
602 .await
603 .map_err(|e: ReqwestError| e.to_string())?;
604
605 Ok(())
606 }
607
608 async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
636 self.call_mcp_tool(&ToolsCallParams {
637 name: "search_docs".to_string(),
638 arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
639 })
640 .await
641 }
642
643 async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
644 self.call_mcp_tool(&ToolsCallParams {
645 name: "search_memory".to_string(),
646 arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
647 })
648 .await
649 }
650
651 async fn slack_read_messages(
652 &self,
653 input: &SlackReadMessagesRequest,
654 ) -> Result<Vec<Content>, String> {
655 self.call_mcp_tool(&ToolsCallParams {
656 name: "slack_read_messages".to_string(),
657 arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
658 })
659 .await
660 }
661
662 async fn slack_read_replies(
663 &self,
664 input: &SlackReadRepliesRequest,
665 ) -> Result<Vec<Content>, String> {
666 self.call_mcp_tool(&ToolsCallParams {
667 name: "slack_read_replies".to_string(),
668 arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
669 })
670 .await
671 }
672
673 async fn slack_send_message(
674 &self,
675 input: &SlackSendMessageRequest,
676 ) -> Result<Vec<Content>, String> {
677 let arguments = json!({
693 "channel": input.channel,
694 "markdown_text": input.mrkdwn_text,
695 "thread_ts": input.thread_ts,
696 });
697
698 self.call_mcp_tool(&ToolsCallParams {
699 name: "slack_send_message".to_string(),
700 arguments,
701 })
702 .await
703 }
704
705 async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
706 let url = format!(
707 "{}/agents/sessions/checkpoints/{}/extract-memory",
708 self.base_url, checkpoint_id
709 );
710
711 let response = self
712 .client
713 .post(&url)
714 .send()
715 .await
716 .map_err(|e: ReqwestError| e.to_string())?;
717
718 let _ = self.handle_response_error(response).await?;
719 Ok(())
720 }
721}