// agent_diva_providers/base.rs
1//! Base trait for LLM providers
2
3use async_trait::async_trait;
4use futures::stream::{self, Stream};
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use std::collections::HashMap;
7use std::pin::Pin;
8use thiserror::Error;
9
/// Error type for provider operations
#[derive(Error, Debug)]
pub enum ProviderError {
    /// Transport-level failure from `reqwest` (connection, timeout, TLS, ...).
    #[error("HTTP request failed: {0}")]
    HttpError(#[from] reqwest::Error),

    /// The provider returned a body that is not valid JSON.
    #[error("JSON parsing failed: {0}")]
    JsonError(#[from] serde_json::Error),

    /// The response parsed but did not have the expected structure.
    #[error("Invalid response: {0}")]
    InvalidResponse(String),

    /// The provider's API reported an error (e.g. an error payload).
    #[error("API error: {0}")]
    ApiError(String),

    /// Local misconfiguration (missing key, bad base URL, ...).
    #[error("Configuration error: {0}")]
    ConfigError(String),
}
28
/// Convenience alias: `Result` specialized to [`ProviderError`].
pub type ProviderResult<T> = Result<T, ProviderError>;

/// Pinned, boxed stream of streaming events; `Send` so it can cross task/thread boundaries.
pub type ProviderEventStream = Pin<Box<dyn Stream<Item = ProviderResult<LLMStreamEvent>> + Send>>;
32
/// A tool call request from the LLM
#[derive(Debug, Clone)]
pub struct ToolCallRequest {
    /// Provider-assigned id; echoed back in the tool-result message.
    pub id: String,
    /// Serialized as `"type"` on the wire (typically "function").
    pub call_type: String,
    /// Name of the tool/function to invoke.
    pub name: String,
    /// Parsed arguments; always a map (see the custom Deserialize impl for how
    /// string-encoded or non-object arguments are normalized).
    pub arguments: HashMap<String, serde_json::Value>,
}
41
42impl Serialize for ToolCallRequest {
43    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
44    where
45        S: Serializer,
46    {
47        use serde::ser::Error as _;
48        use serde::ser::SerializeStruct;
49
50        #[derive(Serialize)]
51        struct Function<'a> {
52            name: &'a str,
53            arguments: String,
54        }
55
56        let arguments = serde_json::to_string(&self.arguments).map_err(|e| {
57            S::Error::custom(format!(
58                "failed to serialize tool call arguments for {}: {}",
59                self.name, e
60            ))
61        })?;
62
63        let mut state = serializer.serialize_struct("ToolCallRequest", 3)?;
64        state.serialize_field("id", &self.id)?;
65        state.serialize_field("type", &self.call_type)?;
66        state.serialize_field(
67            "function",
68            &Function {
69                name: &self.name,
70                arguments,
71            },
72        )?;
73        state.end()
74    }
75}
76
impl<'de> Deserialize<'de> for ToolCallRequest {
    /// Deserialize from either the nested OpenAI shape
    /// (`{"id", "type", "function": {"name", "arguments"}}`) or a flat shape
    /// with top-level `name`/`arguments`. The nested shape wins when both are present.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Nested `function` payload (OpenAI wire format).
        #[derive(Deserialize)]
        struct Function {
            name: String,
            // Either a JSON-encoded string or an already-parsed value.
            arguments: serde_json::Value,
        }

        // Superset of both accepted shapes; optional fields default to `None`.
        #[derive(Deserialize)]
        struct Helper {
            id: String,
            #[serde(rename = "type")]
            call_type: String,
            #[serde(default)]
            function: Option<Function>,
            #[serde(default)]
            name: Option<String>,
            #[serde(default)]
            arguments: Option<serde_json::Value>,
        }

        // Coerce `arguments` into a map:
        // - JSON string: parse it; if unparseable, preserve the text under the
        //   "raw" key instead of failing the whole deserialization.
        // - JSON object: use its entries directly.
        // - anything else: empty map.
        fn normalize_arguments(value: serde_json::Value) -> HashMap<String, serde_json::Value> {
            match value {
                serde_json::Value::String(raw) => serde_json::from_str::<
                    HashMap<String, serde_json::Value>,
                >(&raw)
                .unwrap_or_else(|_| {
                    let mut map = HashMap::new();
                    map.insert("raw".to_string(), serde_json::Value::String(raw));
                    map
                }),
                serde_json::Value::Object(map) => map.into_iter().collect(),
                _ => HashMap::new(),
            }
        }

        let helper = Helper::deserialize(deserializer)?;
        // Prefer the nested OpenAI shape when present.
        if let Some(function) = helper.function {
            return Ok(Self {
                id: helper.id,
                call_type: helper.call_type,
                name: function.name,
                arguments: normalize_arguments(function.arguments),
            });
        }

        // Flat shape: `name` is required, `arguments` is optional.
        let name = helper
            .name
            .ok_or_else(|| serde::de::Error::missing_field("function or name"))?;
        let arguments = helper
            .arguments
            .map(normalize_arguments)
            .unwrap_or_default();

        Ok(Self {
            id: helper.id,
            call_type: helper.call_type,
            name,
            arguments,
        })
    }
}
142
/// Response from an LLM provider
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    /// Assistant text, if any (may be `None` for pure tool-call turns).
    pub content: Option<String>,
    /// Tool invocations requested by the model; empty when none.
    #[serde(default)]
    pub tool_calls: Vec<ToolCallRequest>,
    /// Why generation ended; defaults to "stop" when the provider omits it.
    #[serde(default = "default_finish_reason")]
    pub finish_reason: String,
    /// Token accounting as reported by the provider (keys are provider-specific).
    #[serde(default)]
    pub usage: HashMap<String, i64>,
    /// Chain-of-thought / reasoning text, for providers that expose it.
    #[serde(default)]
    pub reasoning_content: Option<String>,
}
156
/// Serde default for [`LLMResponse::finish_reason`]: a provider that omits
/// the field is treated as having completed normally.
fn default_finish_reason() -> String {
    String::from("stop")
}
160
161impl LLMResponse {
162    /// Check if response contains tool calls
163    pub fn has_tool_calls(&self) -> bool {
164        !self.tool_calls.is_empty()
165    }
166}
167
/// Streaming event emitted by LLM providers
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LLMStreamEvent {
    /// Incremental assistant text output
    TextDelta(String),
    /// Incremental reasoning content
    ReasoningDelta(String),
    /// Incremental tool-call metadata (reserved for advanced UIs)
    ToolCallDelta {
        /// Position of the tool call within the response's tool-call list.
        index: usize,
        /// Tool-call id, present only on the first fragment for a call.
        id: Option<String>,
        /// Tool name, present only once it is known.
        name: Option<String>,
        /// Partial JSON text of the arguments; concatenate deltas to rebuild.
        arguments_delta: Option<String>,
    },
    /// Final completed response
    Completed(LLMResponse),
}
185
/// A message in the chat conversation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// One of "user", "system", "assistant", or "tool" (see the constructors below).
    pub role: String,
    /// Message text; may be empty for assistant turns that only carry tool calls.
    pub content: String,
    /// Optional participant name; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// For role "tool": the id of the tool call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// For role "assistant": tool calls issued in this turn.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallRequest>>,
    /// Reasoning text echoed back to providers that support it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    /// Opaque provider-specific thinking blocks, passed through as raw JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_blocks: Option<Vec<serde_json::Value>>,
}
202
203impl Message {
204    /// Create a user message
205    pub fn user(content: impl Into<String>) -> Self {
206        Self {
207            role: "user".to_string(),
208            content: content.into(),
209            name: None,
210            tool_call_id: None,
211            tool_calls: None,
212            reasoning_content: None,
213            thinking_blocks: None,
214        }
215    }
216
217    /// Create a system message
218    pub fn system(content: impl Into<String>) -> Self {
219        Self {
220            role: "system".to_string(),
221            content: content.into(),
222            name: None,
223            tool_call_id: None,
224            tool_calls: None,
225            reasoning_content: None,
226            thinking_blocks: None,
227        }
228    }
229
230    /// Create an assistant message
231    pub fn assistant(content: impl Into<String>) -> Self {
232        Self {
233            role: "assistant".to_string(),
234            content: content.into(),
235            name: None,
236            tool_call_id: None,
237            tool_calls: None,
238            reasoning_content: None,
239            thinking_blocks: None,
240        }
241    }
242
243    /// Create a tool response message
244    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
245        Self {
246            role: "tool".to_string(),
247            content: content.into(),
248            name: None,
249            tool_call_id: Some(tool_call_id.into()),
250            tool_calls: None,
251            reasoning_content: None,
252            thinking_blocks: None,
253        }
254    }
255}
256
/// Trait for LLM providers
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Send a chat completion request.
    ///
    /// * `messages` — conversation history, oldest first.
    /// * `tools` — optional tool schemas forwarded to the provider as raw JSON.
    /// * `model` — model override; `None` presumably falls back to
    ///   [`Self::get_default_model`] in implementations (TODO: confirm — not
    ///   visible from this trait alone).
    /// * `max_tokens` / `temperature` — sampling parameters.
    async fn chat(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<serde_json::Value>>,
        model: Option<String>,
        max_tokens: i32,
        temperature: f64,
    ) -> ProviderResult<LLMResponse>;

    /// Send a streaming chat completion request.
    ///
    /// Default behavior falls back to non-streaming chat and emits one text delta.
    async fn chat_stream(
        &self,
        messages: Vec<Message>,
        tools: Option<Vec<serde_json::Value>>,
        model: Option<String>,
        max_tokens: i32,
        temperature: f64,
    ) -> ProviderResult<ProviderEventStream> {
        // Run the request once, then replay it as a short, pre-built stream.
        let response = self
            .chat(messages, tools, model, max_tokens, temperature)
            .await?;

        let mut events = Vec::new();
        if let Some(content) = response.content.clone() {
            // Empty content gets no delta; the `Completed` event still carries it.
            if !content.is_empty() {
                events.push(Ok(LLMStreamEvent::TextDelta(content)));
            }
        }
        // `Completed` is always last and holds the full response (tool calls, usage, ...).
        events.push(Ok(LLMStreamEvent::Completed(response)));

        Ok(Box::pin(stream::iter(events)))
    }

    /// Get the default model for this provider
    fn get_default_model(&self) -> String;
}