1use anyhow::Result;
2use async_trait::async_trait;
3use chrono::Utc;
4use serde::Serialize;
5use serde_json::{json, Value};
6
7use super::api_client::{ApiClient, AuthMethod};
8use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
9use super::errors::ProviderError;
10use super::retry::ProviderRetry;
11use super::utils::map_http_error_to_provider_error;
12use crate::conversation::message::{Message, MessageContent};
13
14use crate::mcp_utils::ToolResult;
15use crate::model::ModelConfig;
16use rmcp::model::{object, CallToolRequestParam, Role, Tool};
17
18#[derive(Debug)]
20struct CapabilityFlags(String);
21
22impl CapabilityFlags {
23 fn from_json(value: &serde_json::Value) -> Self {
24 let caps = &value["model_spec"]["capabilities"];
25 let mut s = String::with_capacity(6);
26 macro_rules! flag {
27 ($json_key:literal, $letter:literal) => {
28 if caps
29 .get($json_key)
30 .and_then(|v| v.as_bool())
31 .unwrap_or(false)
32 {
33 s.push($letter);
34 }
35 };
36 }
37 flag!("optimizedForCode", 'c'); flag!("supportsVision", 'v'); flag!("supportsFunctionCalling", 'f');
40 flag!("supportsResponseSchema", 's');
41 flag!("supportsWebSearch", 'w');
42 flag!("supportsReasoning", 'r');
43 CapabilityFlags(s)
44 }
45}
46
47impl std::fmt::Display for CapabilityFlags {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 write!(f, "[{}]", self.0) }
51}
/// Returns the bare model id, dropping any trailing capability suffix
/// (e.g. "llama-3.3-70b [cf]" -> "llama-3.3-70b").
///
/// Uses the first whitespace-separated token; an all-whitespace or empty
/// input is returned unchanged.
fn strip_flags(model: &str) -> &str {
    let mut tokens = model.split_whitespace();
    match tokens.next() {
        Some(first) => first,
        None => model,
    }
}
/// Venice.ai API documentation root, shown in provider metadata.
pub const VENICE_DOC_URL: &str = "https://docs.venice.ai/";
/// Model used when no model is configured.
pub const VENICE_DEFAULT_MODEL: &str = "llama-3.3-70b";
/// Default API host; overridable via the `VENICE_HOST` config key.
pub const VENICE_DEFAULT_HOST: &str = "https://api.venice.ai";
/// Default chat-completions path; overridable via `VENICE_BASE_PATH`.
pub const VENICE_DEFAULT_BASE_PATH: &str = "api/v1/chat/completions";
/// Default model-listing path; overridable via `VENICE_MODELS_PATH`.
pub const VENICE_DEFAULT_MODELS_PATH: &str = "api/v1/models";

// Static model list registered in provider metadata (known_models).
const FALLBACK_MODELS: [&str; 3] = [
    "llama-3.2-3b", "llama-3.3-70b", "mistral-31-24b", ];
73
/// Provider backed by the Venice.ai chat-completions API.
#[derive(Debug, Serialize)]
pub struct VeniceProvider {
    /// HTTP client configured with host and bearer-token auth; not serialized.
    #[serde(skip)]
    api_client: ApiClient,
    /// Relative path of the chat-completions endpoint.
    base_path: String,
    /// Relative path of the model-listing endpoint.
    models_path: String,
    /// Active model configuration (model name already stripped of flag suffix).
    model: ModelConfig,
    /// Provider display name taken from `Self::metadata()`; not serialized.
    #[serde(skip)]
    name: String,
}
84
85impl VeniceProvider {
86 pub async fn from_env(mut model: ModelConfig) -> Result<Self> {
87 let config = crate::config::Config::global();
88 let api_key: String = config.get_secret("VENICE_API_KEY")?;
89 let host: String = config
90 .get_param("VENICE_HOST")
91 .unwrap_or_else(|_| VENICE_DEFAULT_HOST.to_string());
92 let base_path: String = config
93 .get_param("VENICE_BASE_PATH")
94 .unwrap_or_else(|_| VENICE_DEFAULT_BASE_PATH.to_string());
95 let models_path: String = config
96 .get_param("VENICE_MODELS_PATH")
97 .unwrap_or_else(|_| VENICE_DEFAULT_MODELS_PATH.to_string());
98
99 model.model_name = strip_flags(&model.model_name).to_string();
101
102 let auth = AuthMethod::BearerToken(api_key);
103 let api_client = ApiClient::new(host, auth)?;
104
105 let instance = Self {
106 api_client,
107 base_path,
108 models_path,
109 model,
110 name: Self::metadata().name,
111 };
112
113 Ok(instance)
114 }
115
116 async fn post(&self, path: &str, payload: &Value) -> Result<Value, ProviderError> {
117 let response = self.api_client.response_post(path, payload).await?;
118
119 let status = response.status();
120 tracing::debug!("Venice response status: {}", status);
121
122 if !status.is_success() {
123 let error_body = response.text().await.unwrap_or_default();
125
126 tracing::debug!("Full Venice error response: {}", error_body);
128
129 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&error_body) {
131 println!(
133 "Venice API error response: {}",
134 serde_json::to_string_pretty(&json).unwrap_or_else(|_| json.to_string())
135 );
136
137 if let Some(details) = json.get("details") {
139 if let Some(tools) = details.get("tools") {
141 if let Some(errors) = tools.get("_errors") {
142 if errors.to_string().contains("not supported by this model") {
143 let model_name = self.model.model_name.clone();
144 return Err(ProviderError::RequestFailed(
145 format!("The selected model '{}' does not support tool calls. Please select a model that supports tools, such as 'llama-3.3-70b' or 'mistral-31-24b'.", model_name)
146 ));
147 }
148 }
149 }
150 }
151
152 if let Some(context) = json.get("context") {
154 if let Some(issues) = context.get("issues") {
155 if let Some(issues_array) = issues.as_array() {
156 for issue in issues_array {
157 if let Some(message) = issue.get("message").and_then(|m| m.as_str())
158 {
159 if message.contains("tools is not supported by this model") {
160 let model_name = self.model.model_name.clone();
161 return Err(ProviderError::RequestFailed(
162 format!("The selected model '{}' does not support tool calls. Please select a model that supports tools, such as 'llama-3.3-70b' or 'mistral-31-24b'.", model_name)
163 ));
164 }
165 }
166 }
167 }
168 }
169 }
170 }
171
172 let error_json = serde_json::from_str::<Value>(&error_body).ok();
174 return Err(map_http_error_to_provider_error(status, error_json));
175 }
176
177 let response_text = response.text().await?;
178 serde_json::from_str(&response_text).map_err(|e| {
179 ProviderError::RequestFailed(format!(
180 "Failed to parse JSON: {}\nResponse: {}",
181 e, response_text
182 ))
183 })
184 }
185}
186
187#[async_trait]
188impl Provider for VeniceProvider {
189 fn metadata() -> ProviderMetadata {
190 ProviderMetadata::new(
191 "venice",
192 "Venice.ai",
193 "Venice.ai models (Llama, DeepSeek, Mistral) with function calling",
194 VENICE_DEFAULT_MODEL,
195 FALLBACK_MODELS.to_vec(),
196 VENICE_DOC_URL,
197 vec![
198 ConfigKey::new("VENICE_API_KEY", true, true, None),
199 ConfigKey::new("VENICE_HOST", true, false, Some(VENICE_DEFAULT_HOST)),
200 ConfigKey::new(
201 "VENICE_BASE_PATH",
202 true,
203 false,
204 Some(VENICE_DEFAULT_BASE_PATH),
205 ),
206 ConfigKey::new(
207 "VENICE_MODELS_PATH",
208 true,
209 false,
210 Some(VENICE_DEFAULT_MODELS_PATH),
211 ),
212 ],
213 )
214 }
215
216 fn get_name(&self) -> &str {
217 &self.name
218 }
219
220 fn get_model_config(&self) -> ModelConfig {
221 self.model.clone()
222 }
223
224 async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
225 let response = self.api_client.response_get(&self.models_path).await?;
226 let json: serde_json::Value = response.json().await?;
227
228 let mut models = json["data"]
229 .as_array()
230 .ok_or_else(|| ProviderError::RequestFailed("No data field in JSON".to_string()))?
231 .iter()
232 .filter_map(|model| {
233 let id = model["id"].as_str()?.to_owned();
234 let flags = CapabilityFlags::from_json(model);
236 if flags.0.contains('f') {
238 Some(format!("{id} {flags}"))
239 } else {
240 None
241 }
242 })
243 .collect::<Vec<String>>();
244 models.sort();
245 Ok(Some(models))
246 }
247
248 #[tracing::instrument(
249 skip(self, model_config, system, messages, tools),
250 fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
251 )]
252 async fn complete_with_model(
253 &self,
254 model_config: &ModelConfig,
255 system: &str,
256 messages: &[Message],
257 tools: &[Tool],
258 ) -> Result<(Message, ProviderUsage), ProviderError> {
259 let mut formatted_messages = Vec::new();
261
262 if !system.is_empty() {
264 formatted_messages.push(json!({
265 "role": "system",
266 "content": system
267 }));
268 }
269
270 for msg in messages {
272 let content = match msg.role {
274 Role::User => {
275 let text_content: String = msg
277 .content
278 .iter()
279 .filter_map(|c| c.as_text())
280 .collect::<Vec<_>>()
281 .join("\n");
282
283 if !text_content.is_empty() {
285 text_content
286 } else {
287 msg.as_concat_text()
289 }
290 }
291 _ => {
292 let has_tool_calls = msg
294 .content
295 .iter()
296 .any(|c| matches!(c, MessageContent::ToolRequest(_)));
297
298 if has_tool_calls {
299 "".to_string()
302 } else {
303 msg.as_concat_text()
305 }
306 }
307 };
308
309 let mut venice_msg = json!({
311 "role": match msg.role {
312 Role::User => "user",
313 Role::Assistant => "assistant",
314 },
315 "content": content
316 });
317
318 tracing::debug!(
320 "Venice message format: role={:?}, content_len={}, has_tool_calls={}",
321 msg.role,
322 content.len(),
323 msg.content
324 .iter()
325 .any(|c| matches!(c, MessageContent::ToolRequest(_)))
326 );
327
328 if msg.role == Role::Assistant {
330 let tool_calls: Vec<_> = msg
331 .content
332 .iter()
333 .filter_map(|c| c.as_tool_request())
334 .collect();
335
336 if !tool_calls.is_empty() {
337 let venice_tool_calls: Vec<Value> = tool_calls
339 .iter()
340 .filter_map(|tr| {
341 if let ToolResult::Ok(tool_call) = &tr.tool_call {
342 let args_str = tool_call
344 .arguments
345 .as_ref() .map(|map| serde_json::to_string(map).unwrap_or_default())
347 .unwrap_or_default();
348
349 tracing::debug!(
351 "Tool call conversion: id={}, name={}, args_len={}",
352 tr.id,
353 tool_call.name,
354 args_str.len()
355 );
356
357 Some(json!({
359 "id": tr.id,
360 "type": "function",
361 "function": {
362 "name": tool_call.name,
363 "arguments": args_str
364 }
365 }))
366 } else {
367 tracing::warn!("Skipping tool call with error: id={}", tr.id);
368 None
369 }
370 })
371 .collect();
372
373 if !venice_tool_calls.is_empty() {
374 tracing::debug!("Adding {} tool calls to message", venice_tool_calls.len());
375 venice_msg["tool_calls"] = json!(venice_tool_calls);
376 }
377 }
378 }
379
380 {
384 let tool_responses: Vec<_> = msg
385 .content
386 .iter()
387 .filter_map(|c| c.as_tool_response())
388 .collect();
389
390 if !tool_responses.is_empty() && !tool_responses[0].id.is_empty() {
391 venice_msg["tool_call_id"] = json!(tool_responses[0].id);
392 venice_msg["role"] = json!("tool");
394 }
395 }
396
397 formatted_messages.push(venice_msg);
398 }
399
400 let mut payload = json!({
402 "model": strip_flags(&model_config.model_name),
403 "messages": formatted_messages,
404 "stream": false,
405 "temperature": 0.7,
406 "max_tokens": 2048,
407 });
408
409 if !tools.is_empty() {
410 let formatted_tools: Vec<serde_json::Value> = tools
412 .iter()
413 .map(|tool| {
414 json!({
416 "type": "function",
417 "function": {
418 "name": tool.name,
419 "description": tool.description,
420 "parameters": tool.input_schema
421 }
422 })
423 })
424 .collect();
425
426 payload["tools"] = json!(formatted_tools);
427 }
428
429 tracing::debug!("Sending request to Venice API");
430 tracing::debug!("Venice request payload: {}", payload.to_string());
431
432 let response = self
434 .with_retry(|| self.post(&self.base_path, &payload))
435 .await?;
436
437 let response_json = response;
439
440 let tool_calls = response_json["choices"]
442 .get(0)
443 .and_then(|choice| choice["message"]["tool_calls"].as_array());
444
445 if let Some(tool_calls) = tool_calls {
446 if !tool_calls.is_empty() {
447 let mut content = Vec::new();
449
450 for tool_call in tool_calls {
451 let id = tool_call["id"].as_str().unwrap_or("unknown").to_string();
452 let function = tool_call["function"].clone();
453 let name = function["name"].as_str().unwrap_or("unknown").to_string();
454
455 let arguments = if let Some(args_str) = function["arguments"].as_str() {
457 serde_json::from_str::<Value>(args_str)
458 .unwrap_or(function["arguments"].clone())
459 } else {
460 function["arguments"].clone()
461 };
462
463 let tool_call = CallToolRequestParam {
464 name: name.into(),
465 arguments: Some(object(arguments)),
466 };
467
468 let tool_request = MessageContent::tool_request(id, ToolResult::Ok(tool_call));
470
471 content.push(tool_request);
472 }
473
474 let mut message = Message::assistant();
476 for item in content {
477 message = message.with_content(item);
478 }
479
480 return Ok((
481 message,
482 ProviderUsage::new(
483 strip_flags(&model_config.model_name).to_string(),
484 Usage::default(),
485 ),
486 ));
487 }
488 }
489
490 let content = response_json["choices"]
493 .get(0)
494 .and_then(|choice| choice["message"]["content"].as_str())
495 .ok_or_else(|| {
496 tracing::error!("Invalid response format: {:?}", response_json);
497 ProviderError::RequestFailed("Invalid response format: missing content".to_string())
498 })?
499 .to_string();
500
501 let content = vec![MessageContent::text(content)];
503
504 let usage_data = &response_json["usage"];
506 let usage = Usage::new(
507 usage_data["prompt_tokens"].as_i64().map(|v| v as i32),
508 usage_data["completion_tokens"].as_i64().map(|v| v as i32),
509 usage_data["total_tokens"].as_i64().map(|v| v as i32),
510 );
511
512 Ok((
513 Message::new(Role::Assistant, Utc::now().timestamp(), content),
514 ProviderUsage::new(strip_flags(&self.model.model_name).to_string(), usage),
515 ))
516 }
517}
518
519#[cfg(test)]
520mod tests {
521 use super::*;
522
523 #[test]
524 fn test_metadata_structure() {
525 let metadata = VeniceProvider::metadata();
526
527 assert_eq!(metadata.default_model, "llama-3.3-70b");
528 assert!(!metadata.known_models.is_empty());
529
530 assert_eq!(metadata.config_keys.len(), 4);
531 assert_eq!(metadata.config_keys[0].name, "VENICE_API_KEY");
532 assert_eq!(metadata.config_keys[1].name, "VENICE_HOST");
533 assert_eq!(metadata.config_keys[2].name, "VENICE_BASE_PATH");
534 assert_eq!(metadata.config_keys[3].name, "VENICE_MODELS_PATH");
535 }
536}