rig/providers/
mira.rs

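//! Mira API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::mira;
//!
//! // `Client::new` returns a `Result`; it fails when the API key cannot be
//! // encoded as an HTTP header value.
//! let client = mira::Client::new("YOUR_API_KEY").expect("API key should be a valid header value");
//! ```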
10use crate::json_utils::merge;
11use crate::providers::openai::send_compatible_streaming_request;
12use crate::streaming::{StreamingCompletionModel, StreamingResult};
13use crate::{
14    agent::AgentBuilder,
15    completion::{self, CompletionError, CompletionRequest},
16    extractor::ExtractorBuilder,
17    message::{self, AssistantContent, Message, UserContent},
18    OneOrMany,
19};
20use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use serde_json::{json, Value};
24use std::string::FromUtf8Error;
25use thiserror::Error;
26use tracing;
27
28#[derive(Debug, Error)]
29pub enum MiraError {
30    #[error("Invalid API key")]
31    InvalidApiKey,
32    #[error("API error: {0}")]
33    ApiError(u16),
34    #[error("Request error: {0}")]
35    RequestError(#[from] reqwest::Error),
36    #[error("UTF-8 error: {0}")]
37    Utf8Error(#[from] FromUtf8Error),
38    #[error("JSON error: {0}")]
39    JsonError(#[from] serde_json::Error),
40}
41
42#[derive(Debug, Deserialize)]
43struct ApiErrorResponse {
44    message: String,
45}
46
47#[derive(Debug, Deserialize, Clone)]
48pub struct RawMessage {
49    pub role: String,
50    pub content: String,
51}
52
/// Base URL used by `Client::new`; endpoint paths such as `/v1/models` are appended to it.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
54
55impl TryFrom<RawMessage> for message::Message {
56    type Error = CompletionError;
57
58    fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
59        match raw.role.as_str() {
60            "user" => Ok(message::Message::User {
61                content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
62            }),
63            "assistant" => Ok(message::Message::Assistant {
64                content: OneOrMany::one(AssistantContent::Text(message::Text {
65                    text: raw.content,
66                })),
67            }),
68            _ => Err(CompletionError::ResponseError(format!(
69                "Unsupported message role: {}",
70                raw.role
71            ))),
72        }
73    }
74}
75
76#[derive(Debug, Deserialize)]
77#[serde(untagged)]
78pub enum CompletionResponse {
79    Structured {
80        id: String,
81        object: String,
82        created: u64,
83        model: String,
84        choices: Vec<ChatChoice>,
85        #[serde(skip_serializing_if = "Option::is_none")]
86        usage: Option<Usage>,
87    },
88    Simple(String),
89}
90
91#[derive(Debug, Deserialize)]
92pub struct ChatChoice {
93    pub message: RawMessage,
94    #[serde(default)]
95    pub finish_reason: Option<String>,
96    #[serde(default)]
97    pub index: Option<usize>,
98}
99
100#[derive(Debug, Deserialize)]
101struct ModelsResponse {
102    data: Vec<ModelInfo>,
103}
104
105#[derive(Debug, Deserialize)]
106struct ModelInfo {
107    id: String,
108}
109
110#[derive(Clone)]
111/// Client for interacting with the Mira API
112pub struct Client {
113    base_url: String,
114    client: reqwest::Client,
115    headers: HeaderMap,
116}
117
118impl Client {
119    /// Create a new Mira client with the given API key
120    pub fn new(api_key: &str) -> Result<Self, MiraError> {
121        let mut headers = HeaderMap::new();
122        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
123        headers.insert(
124            AUTHORIZATION,
125            HeaderValue::from_str(&format!("Bearer {}", api_key))
126                .map_err(|_| MiraError::InvalidApiKey)?,
127        );
128        headers.insert(
129            reqwest::header::ACCEPT,
130            HeaderValue::from_static("application/json"),
131        );
132        headers.insert(
133            reqwest::header::USER_AGENT,
134            HeaderValue::from_static("rig-client/1.0"),
135        );
136
137        Ok(Self {
138            base_url: MIRA_API_BASE_URL.to_string(),
139            client: reqwest::Client::builder()
140                .build()
141                .expect("Failed to build HTTP client"),
142            headers,
143        })
144    }
145
146    /// Create a new Mira client from the `MIRA_API_KEY` environment variable.
147    /// Panics if the environment variable is not set.
148    pub fn from_env() -> Result<Self, MiraError> {
149        let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
150        Self::new(&api_key)
151    }
152
153    /// Create a new Mira client with a custom base URL and API key
154    pub fn new_with_base_url(
155        api_key: &str,
156        base_url: impl Into<String>,
157    ) -> Result<Self, MiraError> {
158        let mut client = Self::new(api_key)?;
159        client.base_url = base_url.into();
160        Ok(client)
161    }
162
163    /// List available models
164    pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
165        let url = format!("{}/v1/models", self.base_url);
166
167        let response = self
168            .client
169            .get(&url)
170            .headers(self.headers.clone())
171            .send()
172            .await?;
173
174        let status = response.status();
175
176        if !status.is_success() {
177            // Log the error text but don't store it in an unused variable
178            let _error_text = response.text().await.unwrap_or_default();
179            tracing::error!("Error response: {}", _error_text);
180            return Err(MiraError::ApiError(status.as_u16()));
181        }
182
183        let response_text = response.text().await?;
184
185        let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
186            tracing::error!("Failed to parse response: {}", e);
187            MiraError::JsonError(e)
188        })?;
189
190        Ok(models.data.into_iter().map(|model| model.id).collect())
191    }
192
193    /// Create a completion model with the given name.
194    pub fn completion_model(&self, model: &str) -> CompletionModel {
195        CompletionModel::new(self.to_owned(), model)
196    }
197
198    /// Create an agent builder with the given completion model.
199    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
200        AgentBuilder::new(self.completion_model(model))
201    }
202
203    /// Create an extractor builder with the given completion model.
204    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
205        &self,
206        model: &str,
207    ) -> ExtractorBuilder<T, CompletionModel> {
208        ExtractorBuilder::new(self.completion_model(model))
209    }
210}
211
212#[derive(Clone)]
213pub struct CompletionModel {
214    client: Client,
215    /// Name of the model
216    pub model: String,
217}
218
219impl CompletionModel {
220    pub fn new(client: Client, model: &str) -> Self {
221        Self {
222            client,
223            model: model.to_string(),
224        }
225    }
226
227    fn create_completion_request(
228        &self,
229        completion_request: CompletionRequest,
230    ) -> Result<Value, CompletionError> {
231        let mut messages = Vec::new();
232
233        // Add preamble as user message if available
234        if let Some(preamble) = &completion_request.preamble {
235            messages.push(serde_json::json!({
236                "role": "user",
237                "content": preamble.to_string()
238            }));
239        }
240
241        // Add docs
242        if let Some(Message::User { content }) = completion_request.normalized_documents() {
243            let text = content
244                .into_iter()
245                .filter_map(|doc| match doc {
246                    UserContent::Document(doc) => Some(doc.data),
247                    UserContent::Text(text) => Some(text.text),
248
249                    // This should always be `Document`
250                    _ => None,
251                })
252                .collect::<Vec<_>>()
253                .join("\n");
254
255            messages.push(serde_json::json!({
256                "role": "user",
257                "content": text
258            }));
259        }
260
261        // Add chat history
262        for msg in completion_request.chat_history {
263            let (role, content) = match msg {
264                Message::User { content } => {
265                    let text = content
266                        .iter()
267                        .map(|c| match c {
268                            UserContent::Text(text) => &text.text,
269                            _ => "",
270                        })
271                        .collect::<Vec<_>>()
272                        .join("\n");
273                    ("user", text)
274                }
275                Message::Assistant { content } => {
276                    let text = content
277                        .iter()
278                        .map(|c| match c {
279                            AssistantContent::Text(text) => &text.text,
280                            _ => "",
281                        })
282                        .collect::<Vec<_>>()
283                        .join("\n");
284                    ("assistant", text)
285                }
286            };
287            messages.push(serde_json::json!({
288                "role": role,
289                "content": content
290            }));
291        }
292
293        let request = serde_json::json!({
294            "model": self.model,
295            "messages": messages,
296            "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
297            "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
298            "stream": false
299        });
300
301        Ok(request)
302    }
303}
304
305impl completion::CompletionModel for CompletionModel {
306    type Response = CompletionResponse;
307
308    #[cfg_attr(feature = "worker", worker::send)]
309    async fn completion(
310        &self,
311        completion_request: CompletionRequest,
312    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
313        if !completion_request.tools.is_empty() {
314            tracing::warn!(target: "rig",
315                "Tool calls are not supported by the Mira provider. {} tools will be ignored.",
316                completion_request.tools.len()
317            );
318        }
319
320        let mira_request = self.create_completion_request(completion_request)?;
321
322        let response = self
323            .client
324            .client
325            .post(format!("{}/v1/chat/completions", self.client.base_url))
326            .headers(self.client.headers.clone())
327            .json(&mira_request)
328            .send()
329            .await
330            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
331
332        if !response.status().is_success() {
333            let status = response.status().as_u16();
334            let error_text = response.text().await.unwrap_or_default();
335            return Err(CompletionError::ProviderError(format!(
336                "API error: {} - {}",
337                status, error_text
338            )));
339        }
340
341        let response: CompletionResponse = response
342            .json()
343            .await
344            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
345
346        response.try_into()
347    }
348}
349
350impl StreamingCompletionModel for CompletionModel {
351    async fn stream(
352        &self,
353        completion_request: CompletionRequest,
354    ) -> Result<StreamingResult, CompletionError> {
355        let mut request = self.create_completion_request(completion_request)?;
356
357        request = merge(request, json!({"stream": true}));
358
359        let builder = self
360            .client
361            .client
362            .post(format!("{}/v1/chat/completions", self.client.base_url))
363            .headers(self.client.headers.clone())
364            .json(&request);
365
366        send_compatible_streaming_request(builder).await
367    }
368}
369
370impl From<ApiErrorResponse> for CompletionError {
371    fn from(err: ApiErrorResponse) -> Self {
372        CompletionError::ProviderError(err.message)
373    }
374}
375
376impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
377    type Error = CompletionError;
378
379    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
380        let content = match &response {
381            CompletionResponse::Structured { choices, .. } => {
382                let choice = choices.first().ok_or_else(|| {
383                    CompletionError::ResponseError("Response contained no choices".to_owned())
384                })?;
385
386                // Convert RawMessage to message::Message
387                let message = message::Message::try_from(choice.message.clone())?;
388
389                match message {
390                    Message::Assistant { content } => {
391                        if content.is_empty() {
392                            return Err(CompletionError::ResponseError(
393                                "Response contained empty content".to_owned(),
394                            ));
395                        }
396
397                        // Log warning for unsupported content types
398                        for c in content.iter() {
399                            if !matches!(c, AssistantContent::Text(_)) {
400                                tracing::warn!(target: "rig",
401                                    "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
402                                );
403                            }
404                        }
405
406                        content.iter().map(|c| {
407                            match c {
408                                AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
409                                other => Err(CompletionError::ResponseError(
410                                    format!("Unsupported content type: {:?}. The Mira provider currently only supports text content", other)
411                                ))
412                            }
413                        }).collect::<Result<Vec<_>, _>>()?
414                    }
415                    Message::User { .. } => {
416                        tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
417                        return Err(CompletionError::ResponseError(
418                            "Received user message in response where assistant message was expected".to_owned()
419                        ));
420                    }
421                }
422            }
423            CompletionResponse::Simple(text) => {
424                vec![completion::AssistantContent::text(text)]
425            }
426        };
427
428        let choice = OneOrMany::many(content).map_err(|_| {
429            CompletionError::ResponseError(
430                "Response contained no message or tool call (empty)".to_owned(),
431            )
432        })?;
433
434        Ok(completion::CompletionResponse {
435            choice,
436            raw_response: response,
437        })
438    }
439}
440
441#[derive(Clone, Debug, Deserialize)]
442pub struct Usage {
443    pub prompt_tokens: usize,
444    pub total_tokens: usize,
445}
446
447impl std::fmt::Display for Usage {
448    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
449        write!(
450            f,
451            "Prompt tokens: {} Total tokens: {}",
452            self.prompt_tokens, self.total_tokens
453        )
454    }
455}
456
457impl From<Message> for serde_json::Value {
458    fn from(msg: Message) -> Self {
459        match msg {
460            Message::User { content } => {
461                let text = content
462                    .iter()
463                    .map(|c| match c {
464                        UserContent::Text(text) => &text.text,
465                        _ => "",
466                    })
467                    .collect::<Vec<_>>()
468                    .join("\n");
469                serde_json::json!({
470                    "role": "user",
471                    "content": text
472                })
473            }
474            Message::Assistant { content } => {
475                let text = content
476                    .iter()
477                    .map(|c| match c {
478                        AssistantContent::Text(text) => &text.text,
479                        _ => "",
480                    })
481                    .collect::<Vec<_>>()
482                    .join("\n");
483                serde_json::json!({
484                    "role": "assistant",
485                    "content": text
486                })
487            }
488        }
489    }
490}
491
492impl TryFrom<serde_json::Value> for Message {
493    type Error = CompletionError;
494
495    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
496        let role = value["role"].as_str().ok_or_else(|| {
497            CompletionError::ResponseError("Message missing role field".to_owned())
498        })?;
499
500        // Handle both string and array content formats
501        let content = match value.get("content") {
502            Some(content) => match content {
503                serde_json::Value::String(s) => s.clone(),
504                serde_json::Value::Array(arr) => arr
505                    .iter()
506                    .filter_map(|c| {
507                        c.get("text")
508                            .and_then(|t| t.as_str())
509                            .map(|text| text.to_string())
510                    })
511                    .collect::<Vec<_>>()
512                    .join("\n"),
513                _ => {
514                    return Err(CompletionError::ResponseError(
515                        "Message content must be string or array".to_owned(),
516                    ))
517                }
518            },
519            None => {
520                return Err(CompletionError::ResponseError(
521                    "Message missing content field".to_owned(),
522                ))
523            }
524        };
525
526        match role {
527            "user" => Ok(Message::User {
528                content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
529            }),
530            "assistant" => Ok(Message::Assistant {
531                content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
532            }),
533            _ => Err(CompletionError::ResponseError(format!(
534                "Unsupported message role: {}",
535                role
536            ))),
537        }
538    }
539}
540
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// Asserts that `msg` is an assistant message whose first item is the text `expected`.
    fn assert_assistant_text(msg: Message, expected: &str) {
        match msg {
            Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: expected.to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }
    }

    /// Asserts that `msg` is a user message whose first item is the text `expected`.
    fn assert_user_text(msg: Message, expected: &str) {
        match msg {
            Message::User { content } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text(message::Text {
                        text: expected.to_string()
                    })
                );
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_deserialize_message() {
        // String content format
        let assistant_message = Message::try_from(json!({
            "role": "assistant",
            "content": "Hello there, how may I assist you today?"
        }))
        .unwrap();

        let user_message = Message::try_from(json!({
            "role": "user",
            "content": "What can you help me with?"
        }))
        .unwrap();

        // Array content format
        let assistant_message_array = Message::try_from(json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Hello there, how may I assist you today?"
            }]
        }))
        .unwrap();

        assert_assistant_text(
            assistant_message,
            "Hello there, how may I assist you today?",
        );
        assert_user_text(user_message, "What can you help me with?");
        assert_assistant_text(
            assistant_message_array,
            "Hello there, how may I assist you today?",
        );
    }

    #[test]
    fn test_message_conversion() {
        // Round trip: Rig message -> Mira JSON -> Rig message.
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        // Convert to Mira format
        let mira_value: serde_json::Value = original_message.clone().try_into().unwrap();

        // Convert back to our Message type
        let converted_message: Message = mira_value.try_into().unwrap();

        // Convert back to original format
        let final_message: message::Message = converted_message.try_into().unwrap();

        assert_eq!(original_message, final_message);
    }

    #[test]
    fn test_completion_response_conversion() {
        let only_choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };

        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![only_choice],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let completion_response: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
}