//! providers/venice.rs — Venice.ai chat-completions provider.
1use anyhow::Result;
2use async_trait::async_trait;
3use chrono::Utc;
4use serde::Serialize;
5use serde_json::{json, Value};
6
7use super::api_client::{ApiClient, AuthMethod};
8use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
9use super::errors::ProviderError;
10use super::retry::ProviderRetry;
11use super::utils::map_http_error_to_provider_error;
12use crate::conversation::message::{Message, MessageContent};
13
14use crate::mcp_utils::ToolResult;
15use crate::model::ModelConfig;
16use rmcp::model::{object, CallToolRequestParam, Role, Tool};
17
18// ---------- Capability Flags ----------
/// Compact string of single-letter capability flags for a Venice model,
/// e.g. "cvf" = code-optimized + vision + function calling.
/// Rendered as "[cvf]" via the `Display` impl below.
#[derive(Debug)]
struct CapabilityFlags(String);
21
22impl CapabilityFlags {
23    fn from_json(value: &serde_json::Value) -> Self {
24        let caps = &value["model_spec"]["capabilities"];
25        let mut s = String::with_capacity(6);
26        macro_rules! flag {
27            ($json_key:literal, $letter:literal) => {
28                if caps
29                    .get($json_key)
30                    .and_then(|v| v.as_bool())
31                    .unwrap_or(false)
32                {
33                    s.push($letter);
34                }
35            };
36        }
37        flag!("optimizedForCode", 'c'); // code
38        flag!("supportsVision", 'v'); // vision
39        flag!("supportsFunctionCalling", 'f');
40        flag!("supportsResponseSchema", 's');
41        flag!("supportsWebSearch", 'w');
42        flag!("supportsReasoning", 'r');
43        CapabilityFlags(s)
44    }
45}
46
47impl std::fmt::Display for CapabilityFlags {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        write!(f, "[{}]", self.0) // e.g. "[cvfsw]"
50    }
51}
52// ---------- END Capability Flags ----------
53
54// ---------- Helpers ----------
/// Return the raw model id (everything before the first space).
///
/// Model ids may carry a trailing capability suffix (e.g. "llama-3.3-70b [cf]")
/// added by `fetch_supported_models`; this recovers the bare id. An all-whitespace
/// or empty input is returned unchanged.
fn strip_flags(model: &str) -> &str {
    match model.split_whitespace().next() {
        Some(bare_id) => bare_id,
        None => model,
    }
}
59// ---------- END Helpers ----------
60
/// Link to the Venice.ai API documentation.
pub const VENICE_DOC_URL: &str = "https://docs.venice.ai/";
/// Model used when the caller does not specify one.
pub const VENICE_DEFAULT_MODEL: &str = "llama-3.3-70b";
/// Default API host; overridable via the VENICE_HOST config key.
pub const VENICE_DEFAULT_HOST: &str = "https://api.venice.ai";
/// Chat-completions endpoint path, relative to the host.
pub const VENICE_DEFAULT_BASE_PATH: &str = "api/v1/chat/completions";
/// Model-listing endpoint path, relative to the host.
pub const VENICE_DEFAULT_MODELS_PATH: &str = "api/v1/models";

// Fallback models to use when API is unavailable
const FALLBACK_MODELS: [&str; 3] = [
    "llama-3.2-3b",   // Small model with function calling
    "llama-3.3-70b",  // Default model with function calling
    "mistral-31-24b", // Another model with function calling
];
73
/// Provider backed by the Venice.ai chat-completions API.
#[derive(Debug, Serialize)]
pub struct VeniceProvider {
    // HTTP client pre-configured with the host and bearer-token auth.
    #[serde(skip)]
    api_client: ApiClient,
    // Chat-completions path, e.g. "api/v1/chat/completions".
    base_path: String,
    // Model-listing path, e.g. "api/v1/models".
    models_path: String,
    // Active model configuration; `model_name` holds the bare id (flags stripped).
    model: ModelConfig,
    // Provider name taken from `Self::metadata()` ("venice").
    #[serde(skip)]
    name: String,
}
84
85impl VeniceProvider {
86    pub async fn from_env(mut model: ModelConfig) -> Result<Self> {
87        let config = crate::config::Config::global();
88        let api_key: String = config.get_secret("VENICE_API_KEY")?;
89        let host: String = config
90            .get_param("VENICE_HOST")
91            .unwrap_or_else(|_| VENICE_DEFAULT_HOST.to_string());
92        let base_path: String = config
93            .get_param("VENICE_BASE_PATH")
94            .unwrap_or_else(|_| VENICE_DEFAULT_BASE_PATH.to_string());
95        let models_path: String = config
96            .get_param("VENICE_MODELS_PATH")
97            .unwrap_or_else(|_| VENICE_DEFAULT_MODELS_PATH.to_string());
98
99        // Ensure we only keep the bare model id internally
100        model.model_name = strip_flags(&model.model_name).to_string();
101
102        let auth = AuthMethod::BearerToken(api_key);
103        let api_client = ApiClient::new(host, auth)?;
104
105        let instance = Self {
106            api_client,
107            base_path,
108            models_path,
109            model,
110            name: Self::metadata().name,
111        };
112
113        Ok(instance)
114    }
115
116    async fn post(&self, path: &str, payload: &Value) -> Result<Value, ProviderError> {
117        let response = self.api_client.response_post(path, payload).await?;
118
119        let status = response.status();
120        tracing::debug!("Venice response status: {}", status);
121
122        if !status.is_success() {
123            // Read response body for more details on error
124            let error_body = response.text().await.unwrap_or_default();
125
126            // Log full error response for debugging
127            tracing::debug!("Full Venice error response: {}", error_body);
128
129            // Try to parse the error response
130            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&error_body) {
131                // Print the full JSON error for better debugging
132                println!(
133                    "Venice API error response: {}",
134                    serde_json::to_string_pretty(&json).unwrap_or_else(|_| json.to_string())
135                );
136
137                // Check for tool support errors
138                if let Some(details) = json.get("details") {
139                    // Specifically look for tool support issues
140                    if let Some(tools) = details.get("tools") {
141                        if let Some(errors) = tools.get("_errors") {
142                            if errors.to_string().contains("not supported by this model") {
143                                let model_name = self.model.model_name.clone();
144                                return Err(ProviderError::RequestFailed(
145                                    format!("The selected model '{}' does not support tool calls. Please select a model that supports tools, such as 'llama-3.3-70b' or 'mistral-31-24b'.", model_name)
146                                ));
147                            }
148                        }
149                    }
150                }
151
152                // Check for specific error message in context.issues
153                if let Some(context) = json.get("context") {
154                    if let Some(issues) = context.get("issues") {
155                        if let Some(issues_array) = issues.as_array() {
156                            for issue in issues_array {
157                                if let Some(message) = issue.get("message").and_then(|m| m.as_str())
158                                {
159                                    if message.contains("tools is not supported by this model") {
160                                        let model_name = self.model.model_name.clone();
161                                        return Err(ProviderError::RequestFailed(
162                                            format!("The selected model '{}' does not support tool calls. Please select a model that supports tools, such as 'llama-3.3-70b' or 'mistral-31-24b'.", model_name)
163                                        ));
164                                    }
165                                }
166                            }
167                        }
168                    }
169                }
170            }
171
172            // Use the common error mapping function
173            let error_json = serde_json::from_str::<Value>(&error_body).ok();
174            return Err(map_http_error_to_provider_error(status, error_json));
175        }
176
177        let response_text = response.text().await?;
178        serde_json::from_str(&response_text).map_err(|e| {
179            ProviderError::RequestFailed(format!(
180                "Failed to parse JSON: {}\nResponse: {}",
181                e, response_text
182            ))
183        })
184    }
185}
186
187#[async_trait]
188impl Provider for VeniceProvider {
189    fn metadata() -> ProviderMetadata {
190        ProviderMetadata::new(
191            "venice",
192            "Venice.ai",
193            "Venice.ai models (Llama, DeepSeek, Mistral) with function calling",
194            VENICE_DEFAULT_MODEL,
195            FALLBACK_MODELS.to_vec(),
196            VENICE_DOC_URL,
197            vec![
198                ConfigKey::new("VENICE_API_KEY", true, true, None),
199                ConfigKey::new("VENICE_HOST", true, false, Some(VENICE_DEFAULT_HOST)),
200                ConfigKey::new(
201                    "VENICE_BASE_PATH",
202                    true,
203                    false,
204                    Some(VENICE_DEFAULT_BASE_PATH),
205                ),
206                ConfigKey::new(
207                    "VENICE_MODELS_PATH",
208                    true,
209                    false,
210                    Some(VENICE_DEFAULT_MODELS_PATH),
211                ),
212            ],
213        )
214    }
215
216    fn get_name(&self) -> &str {
217        &self.name
218    }
219
220    fn get_model_config(&self) -> ModelConfig {
221        self.model.clone()
222    }
223
224    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
225        let response = self.api_client.response_get(&self.models_path).await?;
226        let json: serde_json::Value = response.json().await?;
227
228        let mut models = json["data"]
229            .as_array()
230            .ok_or_else(|| ProviderError::RequestFailed("No data field in JSON".to_string()))?
231            .iter()
232            .filter_map(|model| {
233                let id = model["id"].as_str()?.to_owned();
234                // Build flags from capabilities
235                let flags = CapabilityFlags::from_json(model);
236                // Only include models that support function calling (have 'f' flag)
237                if flags.0.contains('f') {
238                    Some(format!("{id} {flags}"))
239                } else {
240                    None
241                }
242            })
243            .collect::<Vec<String>>();
244        models.sort();
245        Ok(Some(models))
246    }
247
248    #[tracing::instrument(
249        skip(self, model_config, system, messages, tools),
250        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
251    )]
252    async fn complete_with_model(
253        &self,
254        model_config: &ModelConfig,
255        system: &str,
256        messages: &[Message],
257        tools: &[Tool],
258    ) -> Result<(Message, ProviderUsage), ProviderError> {
259        // Create properly formatted messages for Venice API
260        let mut formatted_messages = Vec::new();
261
262        // Add the system message if present
263        if !system.is_empty() {
264            formatted_messages.push(json!({
265                "role": "system",
266                "content": system
267            }));
268        }
269
270        // Format regular messages according to Venice API requirements
271        for msg in messages {
272            // Venice API expects 'content' to be a string, not an array of MessageContent
273            let content = match msg.role {
274                Role::User => {
275                    // For user messages, concatenate all text content
276                    let text_content: String = msg
277                        .content
278                        .iter()
279                        .filter_map(|c| c.as_text())
280                        .collect::<Vec<_>>()
281                        .join("\n");
282
283                    // If we have text content, use it directly
284                    if !text_content.is_empty() {
285                        text_content
286                    } else {
287                        // Otherwise, try to get a reasonable string representation
288                        msg.as_concat_text()
289                    }
290                }
291                _ => {
292                    // For assistant messages, handle possible tool calls
293                    let has_tool_calls = msg
294                        .content
295                        .iter()
296                        .any(|c| matches!(c, MessageContent::ToolRequest(_)));
297
298                    if has_tool_calls {
299                        // If there are tool calls, we'll handle them separately
300                        // Just use an empty string for content
301                        "".to_string()
302                    } else {
303                        // Otherwise use text content
304                        msg.as_concat_text()
305                    }
306                }
307            };
308
309            // Create basic message with content as string
310            let mut venice_msg = json!({
311                "role": match msg.role {
312                    Role::User => "user",
313                    Role::Assistant => "assistant",
314                },
315                "content": content
316            });
317
318            // Add debug information to tracing
319            tracing::debug!(
320                "Venice message format: role={:?}, content_len={}, has_tool_calls={}",
321                msg.role,
322                content.len(),
323                msg.content
324                    .iter()
325                    .any(|c| matches!(c, MessageContent::ToolRequest(_)))
326            );
327
328            // For assistant messages with tool calls, add them in Venice format
329            if msg.role == Role::Assistant {
330                let tool_calls: Vec<_> = msg
331                    .content
332                    .iter()
333                    .filter_map(|c| c.as_tool_request())
334                    .collect();
335
336                if !tool_calls.is_empty() {
337                    // Transform our tool calls to Venice format
338                    let venice_tool_calls: Vec<Value> = tool_calls
339                        .iter()
340                        .filter_map(|tr| {
341                            if let ToolResult::Ok(tool_call) = &tr.tool_call {
342                                // Safely convert arguments to a JSON string
343                                let args_str = tool_call
344                                    .arguments
345                                    .as_ref() // borrow the Option contents
346                                    .map(|map| serde_json::to_string(map).unwrap_or_default())
347                                    .unwrap_or_default();
348
349                                // Log tool call details for debugging
350                                tracing::debug!(
351                                    "Tool call conversion: id={}, name={}, args_len={}",
352                                    tr.id,
353                                    tool_call.name,
354                                    args_str.len()
355                                );
356
357                                // Convert to Venice format
358                                Some(json!({
359                                    "id": tr.id,
360                                    "type": "function",
361                                    "function": {
362                                        "name": tool_call.name,
363                                        "arguments": args_str
364                                    }
365                                }))
366                            } else {
367                                tracing::warn!("Skipping tool call with error: id={}", tr.id);
368                                None
369                            }
370                        })
371                        .collect();
372
373                    if !venice_tool_calls.is_empty() {
374                        tracing::debug!("Adding {} tool calls to message", venice_tool_calls.len());
375                        venice_msg["tool_calls"] = json!(venice_tool_calls);
376                    }
377                }
378            }
379
380            // For tool messages with tool responses, add required tool_call_id
381            // Check for tool responses regardless of role - they should have an ID
382            // that corresponds to the tool call they're responding to
383            {
384                let tool_responses: Vec<_> = msg
385                    .content
386                    .iter()
387                    .filter_map(|c| c.as_tool_response())
388                    .collect();
389
390                if !tool_responses.is_empty() && !tool_responses[0].id.is_empty() {
391                    venice_msg["tool_call_id"] = json!(tool_responses[0].id);
392                    // Venice expects tool messages to have 'role' = 'tool'
393                    venice_msg["role"] = json!("tool");
394                }
395            }
396
397            formatted_messages.push(venice_msg);
398        }
399
400        // Build Venice-specific payload
401        let mut payload = json!({
402            "model": strip_flags(&model_config.model_name),
403            "messages": formatted_messages,
404            "stream": false,
405            "temperature": 0.7,
406            "max_tokens": 2048,
407        });
408
409        if !tools.is_empty() {
410            // Format tools specifically for Venice API
411            let formatted_tools: Vec<serde_json::Value> = tools
412                .iter()
413                .map(|tool| {
414                    // Format each tool in the expected Venice format
415                    json!({
416                        "type": "function",
417                        "function": {
418                            "name": tool.name,
419                            "description": tool.description,
420                            "parameters": tool.input_schema
421                        }
422                    })
423                })
424                .collect();
425
426            payload["tools"] = json!(formatted_tools);
427        }
428
429        tracing::debug!("Sending request to Venice API");
430        tracing::debug!("Venice request payload: {}", payload.to_string());
431
432        // Send request with retry
433        let response = self
434            .with_retry(|| self.post(&self.base_path, &payload))
435            .await?;
436
437        // Parse the response - response is already a Value from our post method
438        let response_json = response;
439
440        // Handle tool calls from the response if present
441        let tool_calls = response_json["choices"]
442            .get(0)
443            .and_then(|choice| choice["message"]["tool_calls"].as_array());
444
445        if let Some(tool_calls) = tool_calls {
446            if !tool_calls.is_empty() {
447                // Extract tool calls and format for our internal model
448                let mut content = Vec::new();
449
450                for tool_call in tool_calls {
451                    let id = tool_call["id"].as_str().unwrap_or("unknown").to_string();
452                    let function = tool_call["function"].clone();
453                    let name = function["name"].as_str().unwrap_or("unknown").to_string();
454
455                    // Parse arguments string to Value if it's a string
456                    let arguments = if let Some(args_str) = function["arguments"].as_str() {
457                        serde_json::from_str::<Value>(args_str)
458                            .unwrap_or(function["arguments"].clone())
459                    } else {
460                        function["arguments"].clone()
461                    };
462
463                    let tool_call = CallToolRequestParam {
464                        name: name.into(),
465                        arguments: Some(object(arguments)),
466                    };
467
468                    // Create a ToolRequest MessageContent
469                    let tool_request = MessageContent::tool_request(id, ToolResult::Ok(tool_call));
470
471                    content.push(tool_request);
472                }
473
474                // Create message and add each content item
475                let mut message = Message::assistant();
476                for item in content {
477                    message = message.with_content(item);
478                }
479
480                return Ok((
481                    message,
482                    ProviderUsage::new(
483                        strip_flags(&model_config.model_name).to_string(),
484                        Usage::default(),
485                    ),
486                ));
487            }
488        }
489
490        // If we get here, it's a regular text response
491        // Extract content
492        let content = response_json["choices"]
493            .get(0)
494            .and_then(|choice| choice["message"]["content"].as_str())
495            .ok_or_else(|| {
496                tracing::error!("Invalid response format: {:?}", response_json);
497                ProviderError::RequestFailed("Invalid response format: missing content".to_string())
498            })?
499            .to_string();
500
501        // Create a vector with a single text content item
502        let content = vec![MessageContent::text(content)];
503
504        // Extract usage
505        let usage_data = &response_json["usage"];
506        let usage = Usage::new(
507            usage_data["prompt_tokens"].as_i64().map(|v| v as i32),
508            usage_data["completion_tokens"].as_i64().map(|v| v as i32),
509            usage_data["total_tokens"].as_i64().map(|v| v as i32),
510        );
511
512        Ok((
513            Message::new(Role::Assistant, Utc::now().timestamp(), content),
514            ProviderUsage::new(strip_flags(&self.model.model_name).to_string(), usage),
515        ))
516    }
517}
518
#[cfg(test)]
mod tests {
    use super::*;

    /// Metadata must expose the expected default model and config keys.
    #[test]
    fn test_metadata_structure() {
        let metadata = VeniceProvider::metadata();

        assert_eq!(metadata.default_model, "llama-3.3-70b");
        assert!(!metadata.known_models.is_empty());

        let expected_keys = [
            "VENICE_API_KEY",
            "VENICE_HOST",
            "VENICE_BASE_PATH",
            "VENICE_MODELS_PATH",
        ];
        assert_eq!(metadata.config_keys.len(), expected_keys.len());
        for (key, expected) in metadata.config_keys.iter().zip(expected_keys) {
            assert_eq!(key.name, expected);
        }
    }
}