Skip to main content

agent_diva_cli/
client.rs

1use agent_diva_agent::AgentEvent;
2use anyhow::Result;
3use eventsource_stream::Eventsource;
4use futures::StreamExt;
5use reqwest::Client;
6use serde::Deserialize;
7use serde_json::Value;
8use tokio::sync::mpsc;
9
10pub struct ApiClient {
11    client: Client,
12    base_url: String,
13}
14
15#[derive(Deserialize)]
16struct ToolStartEvent {
17    name: String,
18    #[serde(alias = "args")]
19    args_preview: String,
20    id: String,
21}
22
23#[derive(Deserialize)]
24struct ToolFinishEvent {
25    name: String,
26    result: String,
27    error: bool,
28    id: String,
29}
30
31#[derive(Deserialize)]
32struct ToolDeltaEvent {
33    name: String,
34    delta: String,
35}
36
37impl ApiClient {
38    pub fn new(base_url: Option<String>) -> Self {
39        Self {
40            client: Client::new(),
41            base_url: base_url.unwrap_or_else(|| "http://localhost:3000/api".to_string()),
42        }
43    }
44
45    pub async fn chat_with_target(
46        &self,
47        message: String,
48        channel: Option<&str>,
49        chat_id: Option<&str>,
50        event_tx: mpsc::UnboundedSender<AgentEvent>,
51    ) -> Result<()> {
52        let url = format!("{}/chat", self.base_url);
53        let mut payload = serde_json::json!({ "message": message });
54        if let Some(channel) = channel {
55            payload["channel"] = serde_json::Value::String(channel.to_string());
56        }
57        if let Some(chat_id) = chat_id {
58            payload["chat_id"] = serde_json::Value::String(chat_id.to_string());
59        }
60        let response = self.client.post(&url).json(&payload).send().await?;
61
62        if !response.status().is_success() {
63            anyhow::bail!("Server returned error: {}", response.status());
64        }
65
66        let mut stream = response.bytes_stream().eventsource();
67
68        while let Some(event) = stream.next().await {
69            match event {
70                Ok(event) => match event.event.as_str() {
71                    "delta" => {
72                        let _ = event_tx.send(AgentEvent::AssistantDelta { text: event.data });
73                    }
74                    "final" => {
75                        let _ = event_tx.send(AgentEvent::FinalResponse {
76                            content: event.data,
77                        });
78                    }
79                    "tool_start" => {
80                        if let Ok(data) = serde_json::from_str::<ToolStartEvent>(&event.data) {
81                            let _ = event_tx.send(AgentEvent::ToolCallStarted {
82                                name: data.name,
83                                args_preview: data.args_preview,
84                                call_id: data.id,
85                            });
86                        }
87                    }
88                    "tool_finish" => {
89                        if let Ok(data) = serde_json::from_str::<ToolFinishEvent>(&event.data) {
90                            let _ = event_tx.send(AgentEvent::ToolCallFinished {
91                                name: data.name,
92                                result: data.result,
93                                is_error: data.error,
94                                call_id: data.id,
95                            });
96                        }
97                    }
98                    "tool_delta" => {
99                        if let Ok(data) = serde_json::from_str::<ToolDeltaEvent>(&event.data) {
100                            let _ = event_tx.send(AgentEvent::ToolCallDelta {
101                                name: Some(data.name),
102                                args_delta: data.delta,
103                            });
104                        }
105                    }
106                    "error" => {
107                        let _ = event_tx.send(AgentEvent::Error {
108                            message: event.data,
109                        });
110                    }
111                    _ => {}
112                },
113                Err(e) => {
114                    let _ = event_tx.send(AgentEvent::Error {
115                        message: e.to_string(),
116                    });
117                }
118            }
119        }
120        Ok(())
121    }
122
123    pub async fn stop(&self, channel: Option<&str>, chat_id: Option<&str>) -> Result<bool> {
124        let url = format!("{}/chat/stop", self.base_url);
125        let mut payload = serde_json::json!({});
126        if let Some(channel) = channel {
127            payload["channel"] = serde_json::Value::String(channel.to_string());
128        }
129        if let Some(chat_id) = chat_id {
130            payload["chat_id"] = serde_json::Value::String(chat_id.to_string());
131        }
132
133        let response = self.client.post(&url).json(&payload).send().await?;
134        if !response.status().is_success() {
135            anyhow::bail!("Server returned error: {}", response.status());
136        }
137
138        let body: Value = response.json().await?;
139        if body.get("status").and_then(|v| v.as_str()) != Some("ok") {
140            let msg = body
141                .get("message")
142                .and_then(|v| v.as_str())
143                .unwrap_or("unknown error");
144            anyhow::bail!("Stop failed: {}", msg);
145        }
146        Ok(body
147            .get("stopped")
148            .and_then(|v| v.as_bool())
149            .unwrap_or(true))
150    }
151}