1use crate::AgentProvider;
2use crate::models::*;
3use async_trait::async_trait;
4use eventsource_stream::Eventsource;
5use futures_util::Stream;
6use futures_util::StreamExt;
7use reqwest::header::HeaderMap;
8use reqwest::{Client as ReqwestClient, Error as ReqwestError, Response, header};
9use rmcp::model::Content;
10use rmcp::model::JsonRpcResponse;
11use serde::Deserialize;
12use serde_json::json;
13use stakpak_shared::models::integrations::openai::{
14 AgentModel, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse,
15 ChatMessage, Tool,
16};
17use stakpak_shared::tls_client::TlsClientConfig;
18use stakpak_shared::tls_client::create_tls_client;
19use uuid::Uuid;
20
21#[derive(Clone, Debug)]
22pub struct RemoteClient {
23 client: ReqwestClient,
24 base_url: String,
25}
26
/// Connection settings consumed by [`RemoteClient::new`].
///
/// `Debug` is implemented by hand so the API key is never written to logs or panic
/// messages; only whether a key is present is shown.
#[derive(Clone)]
pub struct ClientConfig {
    /// Bearer token sent in the `Authorization` header; `None` means the user has not
    /// logged in yet.
    pub api_key: Option<String>,
    /// Base URL of the Stakpak API, without the `/v1` suffix (the client appends it).
    pub api_endpoint: String,
}

impl std::fmt::Debug for ClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientConfig")
            // Never print the secret itself.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("api_endpoint", &self.api_endpoint)
            .finish()
    }
}
32
33#[derive(Deserialize)]
34struct ApiError {
35 error: ApiErrorDetail,
36}
37
38#[derive(Deserialize)]
39struct ApiErrorDetail {
40 key: String,
41 message: String,
42}
43
44impl RemoteClient {
45 async fn handle_response_error(&self, response: Response) -> Result<Response, String> {
46 if response.status().is_success() {
47 Ok(response)
48 } else {
49 match response.json::<ApiError>().await {
50 Ok(response) => {
51 if response.error.key == "EXCEEDED_API_LIMIT" {
52 Err(format!(
53 "{}.\n\nPlease top up your account at https://stakpak.dev/settings/billing to keep Stakpaking.",
54 response.error.message
55 ))
56 } else {
57 Err(response.error.message)
58 }
59 }
60 Err(e) => Err(e.to_string()),
61 }
62 }
63 }
64
65 async fn call_mcp_tool(&self, input: &ToolsCallParams) -> Result<Vec<Content>, String> {
66 let url = format!("{}/mcp", self.base_url);
67
68 let payload = json!({
69 "jsonrpc": "2.0",
70 "method": "tools/call",
71 "params": {
72 "name": input.name,
73 "arguments": input.arguments,
74 },
75 "id": Uuid::new_v4().to_string(),
76 });
77
78 let response = self
79 .client
80 .post(&url)
81 .json(&payload)
82 .send()
83 .await
84 .map_err(|e: ReqwestError| e.to_string())?;
85
86 let response = self.handle_response_error(response).await?;
87
88 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
89
90 match serde_json::from_value::<JsonRpcResponse<ToolsCallResponse>>(value.clone()) {
91 Ok(response) => Ok(response.result.content),
92 Err(e) => {
93 eprintln!("Failed to deserialize response: {}", e);
94 eprintln!("Raw response: {}", value);
95 Err("Failed to deserialize response:".into())
96 }
97 }
98 }
99
100 pub fn new(config: &ClientConfig) -> Result<Self, String> {
101 if config.api_key.is_none() {
102 return Err("API Key not found, please login".into());
103 }
104
105 let mut headers = header::HeaderMap::new();
106 headers.insert(
107 header::AUTHORIZATION,
108 header::HeaderValue::from_str(&format!("Bearer {}", config.api_key.clone().unwrap()))
109 .expect("Invalid API key format"),
110 );
111 headers.insert(
112 header::USER_AGENT,
113 header::HeaderValue::from_str(&format!("Stakpak/{}", env!("CARGO_PKG_VERSION")))
114 .expect("Invalid user agent format"),
115 );
116
117 let client = create_tls_client(
118 TlsClientConfig::default()
119 .with_headers(headers)
120 .with_timeout(std::time::Duration::from_secs(300)),
121 )?;
122
123 Ok(Self {
124 client,
125 base_url: config.api_endpoint.clone() + "/v1",
126 })
127 }
128}
129
130#[async_trait]
131impl AgentProvider for RemoteClient {
132 async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
133 let url = format!("{}/account", self.base_url);
134
135 let response = self
136 .client
137 .get(&url)
138 .send()
139 .await
140 .map_err(|e: ReqwestError| e.to_string())?;
141
142 let response = self.handle_response_error(response).await?;
143
144 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
145 match serde_json::from_value::<GetMyAccountResponse>(value.clone()) {
146 Ok(response) => Ok(response),
147 Err(e) => {
148 eprintln!("Failed to deserialize response: {}", e);
149 eprintln!("Raw response: {}", value);
150 Err("Failed to deserialize response:".into())
151 }
152 }
153 }
154
155 async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
156 let url = format!("{}/rules", self.base_url);
157
158 let response = self
159 .client
160 .get(&url)
161 .send()
162 .await
163 .map_err(|e: ReqwestError| e.to_string())?;
164
165 let response = self.handle_response_error(response).await?;
166
167 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
168 match serde_json::from_value::<ListRulebooksResponse>(value.clone()) {
169 Ok(response) => Ok(response.results),
170 Err(e) => {
171 eprintln!("Failed to deserialize response: {}", e);
172 eprintln!("Raw response: {}", value);
173 Err("Failed to deserialize response:".into())
174 }
175 }
176 }
177
178 async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
179 let encoded_uri = urlencoding::encode(uri);
181 let url = format!("{}/rules/{}", self.base_url, encoded_uri);
182
183 let response = self
184 .client
185 .get(&url)
186 .send()
187 .await
188 .map_err(|e: ReqwestError| e.to_string())?;
189
190 let response = self.handle_response_error(response).await?;
191
192 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
193 match serde_json::from_value::<RuleBook>(value.clone()) {
194 Ok(response) => Ok(response),
195 Err(e) => {
196 eprintln!("Failed to deserialize response: {}", e);
197 eprintln!("Raw response: {}", value);
198 Err("Failed to deserialize response:".into())
199 }
200 }
201 }
202
203 async fn create_rulebook(
204 &self,
205 uri: &str,
206 description: &str,
207 content: &str,
208 tags: Vec<String>,
209 visibility: Option<RuleBookVisibility>,
210 ) -> Result<CreateRuleBookResponse, String> {
211 let url = format!("{}/rules", self.base_url);
212
213 let input = CreateRuleBookInput {
214 uri: uri.to_string(),
215 description: description.to_string(),
216 content: content.to_string(),
217 tags,
218 visibility,
219 };
220
221 let response = self
222 .client
223 .post(&url)
224 .json(&input)
225 .send()
226 .await
227 .map_err(|e: ReqwestError| e.to_string())?;
228
229 if !response.status().is_success() {
231 let status = response.status();
232 let error_text = response
233 .text()
234 .await
235 .unwrap_or_else(|_| "Unknown error".to_string());
236 return Err(format!("API error ({}): {}", status, error_text));
237 }
238
239 let response_text = response.text().await.map_err(|e| e.to_string())?;
241
242 if let Ok(value) = serde_json::from_str::<serde_json::Value>(&response_text) {
244 match serde_json::from_value::<CreateRuleBookResponse>(value.clone()) {
245 Ok(response) => return Ok(response),
246 Err(e) => {
247 eprintln!("Failed to deserialize JSON response: {}", e);
248 eprintln!("Raw response: {}", value);
249 }
250 }
251 }
252
253 if response_text.starts_with("id: ") {
255 let id = response_text.trim_start_matches("id: ").trim().to_string();
256 return Ok(CreateRuleBookResponse { id });
257 }
258
259 Err(format!("Unexpected response format: {}", response_text))
260 }
261
262 async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
263 let encoded_uri = urlencoding::encode(uri);
264 let url = format!("{}/rules/{}", self.base_url, encoded_uri);
265
266 let response = self
267 .client
268 .delete(&url)
269 .send()
270 .await
271 .map_err(|e: ReqwestError| e.to_string())?;
272
273 let _response = self.handle_response_error(response).await?;
274
275 Ok(())
276 }
277
278 async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
279 let url = format!("{}/agents/sessions", self.base_url);
280
281 let response = self
282 .client
283 .get(&url)
284 .send()
285 .await
286 .map_err(|e: ReqwestError| e.to_string())?;
287
288 let response = self.handle_response_error(response).await?;
289
290 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
291 match serde_json::from_value::<Vec<AgentSession>>(value.clone()) {
292 Ok(response) => Ok(response),
293 Err(e) => {
294 eprintln!("Failed to deserialize response: {}", e);
295 eprintln!("Raw response: {}", value);
296 Err("Failed to deserialize response:".into())
297 }
298 }
299 }
300
301 async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
302 let url = format!("{}/agents/sessions/{}", self.base_url, session_id);
303
304 let response = self
305 .client
306 .get(&url)
307 .send()
308 .await
309 .map_err(|e: ReqwestError| e.to_string())?;
310
311 let response = self.handle_response_error(response).await?;
312
313 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
314
315 match serde_json::from_value::<AgentSession>(value.clone()) {
316 Ok(response) => Ok(response),
317 Err(e) => {
318 eprintln!("Failed to deserialize response: {}", e);
319 eprintln!("Raw response: {}", value);
320 Err("Failed to deserialize response:".into())
321 }
322 }
323 }
324
325 async fn get_agent_session_stats(&self, session_id: Uuid) -> Result<AgentSessionStats, String> {
326 let url = format!("{}/agents/sessions/{}/stats", self.base_url, session_id);
327
328 let response = self
329 .client
330 .get(&url)
331 .send()
332 .await
333 .map_err(|e: ReqwestError| e.to_string())?;
334
335 let response = self.handle_response_error(response).await?;
336
337 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
338
339 match serde_json::from_value::<AgentSessionStats>(value.clone()) {
340 Ok(response) => Ok(response),
341 Err(e) => {
342 eprintln!("Failed to deserialize response: {}", e);
343 eprintln!("Raw response: {}", value);
344 Err("Failed to deserialize response:".into())
345 }
346 }
347 }
348
349 async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
350 let url = format!("{}/agents/checkpoints/{}", self.base_url, checkpoint_id);
351
352 let response = self
353 .client
354 .get(&url)
355 .send()
356 .await
357 .map_err(|e: ReqwestError| e.to_string())?;
358
359 let response = self.handle_response_error(response).await?;
360
361 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
362 match serde_json::from_value::<RunAgentOutput>(value.clone()) {
363 Ok(response) => Ok(response),
364 Err(e) => {
365 eprintln!("Failed to deserialize response: {}", e);
366 eprintln!("Raw response: {}", value);
367 Err("Failed to deserialize response:".into())
368 }
369 }
370 }
371
372 async fn get_agent_session_latest_checkpoint(
373 &self,
374 session_id: Uuid,
375 ) -> Result<RunAgentOutput, String> {
376 let url = format!(
377 "{}/agents/sessions/{}/checkpoints/latest",
378 self.base_url, session_id
379 );
380
381 let response = self
382 .client
383 .get(&url)
384 .send()
385 .await
386 .map_err(|e: ReqwestError| e.to_string())?;
387
388 let response = self.handle_response_error(response).await?;
389
390 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
391 match serde_json::from_value::<RunAgentOutput>(value.clone()) {
392 Ok(response) => Ok(response),
393 Err(e) => {
394 eprintln!("Failed to deserialize response: {}", e);
395 eprintln!("Raw response: {}", value);
396 Err("Failed to deserialize response:".into())
397 }
398 }
399 }
400
401 async fn chat_completion(
402 &self,
403 model: AgentModel,
404 messages: Vec<ChatMessage>,
405 tools: Option<Vec<Tool>>,
406 ) -> Result<ChatCompletionResponse, String> {
407 let url = format!("{}/agents/openai/v1/chat/completions", self.base_url);
408
409 let input = ChatCompletionRequest::new(model.to_string(), messages, tools, None);
410
411 let response = self
412 .client
413 .post(&url)
414 .json(&input)
415 .send()
416 .await
417 .map_err(|e: ReqwestError| e.to_string())?;
418
419 let response = self.handle_response_error(response).await?;
420
421 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
422
423 match serde_json::from_value::<ChatCompletionResponse>(value.clone()) {
424 Ok(response) => Ok(response),
425 Err(e) => {
426 eprintln!("Failed to deserialize response: {}", e);
427 eprintln!("Raw response: {}", value);
428 Err("Failed to deserialize response:".into())
429 }
430 }
431 }
432
433 async fn chat_completion_stream(
434 &self,
435 model: AgentModel,
436 messages: Vec<ChatMessage>,
437 tools: Option<Vec<Tool>>,
438 headers: Option<HeaderMap>,
439 ) -> Result<
440 (
441 std::pin::Pin<
442 Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
443 >,
444 Option<String>,
445 ),
446 String,
447 > {
448 let url = format!("{}/agents/openai/v1/chat/completions", self.base_url);
449
450 let input = ChatCompletionRequest::new(model.to_string(), messages, tools, Some(true));
451
452 let response = self
453 .client
454 .post(&url)
455 .headers(headers.unwrap_or_default())
456 .json(&input)
457 .send()
458 .await
459 .map_err(|e: ReqwestError| e.to_string())?;
460
461 let content_type = response
463 .headers()
464 .get("content-type")
465 .and_then(|v| v.to_str().ok())
466 .unwrap_or("unknown");
467
468 let request_id = response
470 .headers()
471 .get("x-request-id")
472 .and_then(|v| v.to_str().ok())
473 .map(|s| s.to_string());
474
475 if !content_type.contains("event-stream") && !content_type.contains("text/event-stream") {
477 let status = response.status();
478 let error_body = response
479 .text()
480 .await
481 .unwrap_or_else(|_| "Failed to read error body".to_string());
482 return Err(format!(
483 "Server returned non-stream response ({}): {}",
484 status, error_body
485 ));
486 }
487
488 let response = self.handle_response_error(response).await?;
489 let stream = response.bytes_stream().eventsource().map(|event| {
490 event
491 .map_err(|err| {
492 eprintln!("stream: failed to read response: {:?}", err);
493 ApiStreamError::Unknown("Failed to read response".to_string())
494 })
495 .and_then(|event| match event.event.as_str() {
496 "error" => Err(ApiStreamError::from(event.data)),
497 _ => serde_json::from_str::<ChatCompletionStreamResponse>(&event.data).map_err(
498 |_| {
499 ApiStreamError::Unknown(
500 "Failed to parse JSON from Anthropic response".to_string(),
501 )
502 },
503 ),
504 })
505 });
506
507 Ok((Box::pin(stream), request_id))
508 }
509
510 async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
511 let url = format!("{}/agents/requests/{}/cancel", self.base_url, request_id);
512 self.client
513 .post(&url)
514 .send()
515 .await
516 .map_err(|e: ReqwestError| e.to_string())?;
517
518 Ok(())
519 }
520
521 async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
549 self.call_mcp_tool(&ToolsCallParams {
550 name: "search_docs".to_string(),
551 arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
552 })
553 .await
554 }
555
556 async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
557 self.call_mcp_tool(&ToolsCallParams {
558 name: "search_memory".to_string(),
559 arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
560 })
561 .await
562 }
563
564 async fn slack_read_messages(
565 &self,
566 input: &SlackReadMessagesRequest,
567 ) -> Result<Vec<Content>, String> {
568 self.call_mcp_tool(&ToolsCallParams {
569 name: "slack_read_messages".to_string(),
570 arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
571 })
572 .await
573 }
574
575 async fn slack_read_replies(
576 &self,
577 input: &SlackReadRepliesRequest,
578 ) -> Result<Vec<Content>, String> {
579 self.call_mcp_tool(&ToolsCallParams {
580 name: "slack_read_replies".to_string(),
581 arguments: serde_json::to_value(input).map_err(|e| e.to_string())?,
582 })
583 .await
584 }
585
586 async fn slack_send_message(
587 &self,
588 input: &SlackSendMessageRequest,
589 ) -> Result<Vec<Content>, String> {
590 let arguments = json!({
606 "channel": input.channel,
607 "markdown_text": input.mrkdwn_text,
608 "thread_ts": input.thread_ts,
609 });
610
611 self.call_mcp_tool(&ToolsCallParams {
612 name: "slack_send_message".to_string(),
613 arguments,
614 })
615 .await
616 }
617
618 async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
619 let url = format!(
620 "{}/agents/sessions/checkpoints/{}/extract-memory",
621 self.base_url, checkpoint_id
622 );
623
624 let response = self
625 .client
626 .post(&url)
627 .send()
628 .await
629 .map_err(|e: ReqwestError| e.to_string())?;
630
631 let _ = self.handle_response_error(response).await?;
632 Ok(())
633 }
634}