rig/providers/
mira.rs

1//! Mira API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::mira;
6//!
//! let client = mira::Client::new("YOUR_API_KEY").expect("Failed to create Mira client");
8//!
9//! ```
10use crate::client::{CompletionClient, ProviderClient};
11use crate::json_utils::merge;
12use crate::providers::openai;
13use crate::providers::openai::send_compatible_streaming_request;
14use crate::streaming::StreamingCompletionResponse;
15use crate::{
16    OneOrMany,
17    completion::{self, CompletionError, CompletionRequest},
18    impl_conversion_traits,
19    message::{self, AssistantContent, Message, UserContent},
20};
21use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
22use serde::Deserialize;
23use serde_json::{Value, json};
24use std::string::FromUtf8Error;
25use thiserror::Error;
26use tracing;
27
28#[derive(Debug, Error)]
29pub enum MiraError {
30    #[error("Invalid API key")]
31    InvalidApiKey,
32    #[error("API error: {0}")]
33    ApiError(u16),
34    #[error("Request error: {0}")]
35    RequestError(#[from] reqwest::Error),
36    #[error("UTF-8 error: {0}")]
37    Utf8Error(#[from] FromUtf8Error),
38    #[error("JSON error: {0}")]
39    JsonError(#[from] serde_json::Error),
40}
41
42#[derive(Debug, Deserialize)]
43struct ApiErrorResponse {
44    message: String,
45}
46
47#[derive(Debug, Deserialize, Clone)]
48pub struct RawMessage {
49    pub role: String,
50    pub content: String,
51}
52
/// Base URL used by [`Client::new`]; endpoint paths such as `/v1/models` are appended to it.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
54
55impl TryFrom<RawMessage> for message::Message {
56    type Error = CompletionError;
57
58    fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
59        match raw.role.as_str() {
60            "user" => Ok(message::Message::User {
61                content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
62            }),
63            "assistant" => Ok(message::Message::Assistant {
64                id: None,
65                content: OneOrMany::one(AssistantContent::Text(message::Text {
66                    text: raw.content,
67                })),
68            }),
69            _ => Err(CompletionError::ResponseError(format!(
70                "Unsupported message role: {}",
71                raw.role
72            ))),
73        }
74    }
75}
76
77#[derive(Debug, Deserialize)]
78#[serde(untagged)]
79pub enum CompletionResponse {
80    Structured {
81        id: String,
82        object: String,
83        created: u64,
84        model: String,
85        choices: Vec<ChatChoice>,
86        #[serde(skip_serializing_if = "Option::is_none")]
87        usage: Option<Usage>,
88    },
89    Simple(String),
90}
91
92#[derive(Debug, Deserialize)]
93pub struct ChatChoice {
94    pub message: RawMessage,
95    #[serde(default)]
96    pub finish_reason: Option<String>,
97    #[serde(default)]
98    pub index: Option<usize>,
99}
100
101#[derive(Debug, Deserialize)]
102struct ModelsResponse {
103    data: Vec<ModelInfo>,
104}
105
106#[derive(Debug, Deserialize)]
107struct ModelInfo {
108    id: String,
109}
110
111#[derive(Clone)]
112/// Client for interacting with the Mira API
113pub struct Client {
114    base_url: String,
115    http_client: reqwest::Client,
116    api_key: String,
117    headers: HeaderMap,
118}
119
120impl std::fmt::Debug for Client {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.debug_struct("Client")
123            .field("base_url", &self.base_url)
124            .field("http_client", &self.http_client)
125            .field("api_key", &"<REDACTED>")
126            .field("headers", &self.headers)
127            .finish()
128    }
129}
130
131impl Client {
132    /// Create a new Mira client with the given API key
133    pub fn new(api_key: &str) -> Result<Self, MiraError> {
134        let mut headers = HeaderMap::new();
135        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
136        headers.insert(
137            reqwest::header::ACCEPT,
138            HeaderValue::from_static("application/json"),
139        );
140        headers.insert(
141            reqwest::header::USER_AGENT,
142            HeaderValue::from_static("rig-client/1.0"),
143        );
144
145        Ok(Self {
146            base_url: MIRA_API_BASE_URL.to_string(),
147            api_key: api_key.to_string(),
148            http_client: reqwest::Client::builder()
149                .build()
150                .expect("Failed to build HTTP client"),
151            headers,
152        })
153    }
154
155    /// Create a new Mira client with a custom base URL and API key
156    pub fn new_with_base_url(
157        api_key: &str,
158        base_url: impl Into<String>,
159    ) -> Result<Self, MiraError> {
160        let mut client = Self::new(api_key)?;
161        client.base_url = base_url.into();
162        Ok(client)
163    }
164
165    /// Use your own `reqwest::Client`.
166    /// The required headers will be automatically attached upon trying to make a request.
167    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
168        self.http_client = client;
169
170        self
171    }
172
173    /// List available models
174    pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
175        let url = format!("{}/v1/models", self.base_url);
176
177        let response = self
178            .http_client
179            .get(&url)
180            .bearer_auth(&self.api_key)
181            .headers(self.headers.clone())
182            .send()
183            .await?;
184
185        let status = response.status();
186
187        if !status.is_success() {
188            // Log the error text but don't store it in an unused variable
189            let _error_text = response.text().await.unwrap_or_default();
190            tracing::error!("Error response: {}", _error_text);
191            return Err(MiraError::ApiError(status.as_u16()));
192        }
193
194        let response_text = response.text().await?;
195
196        let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
197            tracing::error!("Failed to parse response: {}", e);
198            MiraError::JsonError(e)
199        })?;
200
201        Ok(models.data.into_iter().map(|model| model.id).collect())
202    }
203}
204
205impl ProviderClient for Client {
206    /// Create a new Mira client from the `MIRA_API_KEY` environment variable.
207    /// Panics if the environment variable is not set.
208    fn from_env() -> Self {
209        let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
210        Self::new(&api_key).expect("Could not create Mira Client")
211    }
212
213    fn from_val(input: crate::client::ProviderValue) -> Self {
214        let crate::client::ProviderValue::Simple(api_key) = input else {
215            panic!("Incorrect provider value type")
216        };
217        Self::new(&api_key).unwrap()
218    }
219}
220
221impl CompletionClient for Client {
222    type CompletionModel = CompletionModel;
223    /// Create a completion model with the given name.
224    fn completion_model(&self, model: &str) -> CompletionModel {
225        CompletionModel::new(self.to_owned(), model)
226    }
227}
228
229impl_conversion_traits!(
230    AsEmbeddings,
231    AsTranscription,
232    AsImageGeneration,
233    AsAudioGeneration for Client
234);
235
236#[derive(Clone)]
237pub struct CompletionModel {
238    client: Client,
239    /// Name of the model
240    pub model: String,
241}
242
243impl CompletionModel {
244    pub fn new(client: Client, model: &str) -> Self {
245        Self {
246            client,
247            model: model.to_string(),
248        }
249    }
250
251    fn create_completion_request(
252        &self,
253        completion_request: CompletionRequest,
254    ) -> Result<Value, CompletionError> {
255        let mut messages = Vec::new();
256
257        // Add preamble as user message if available
258        if let Some(preamble) = &completion_request.preamble {
259            messages.push(serde_json::json!({
260                "role": "user",
261                "content": preamble.to_string()
262            }));
263        }
264
265        // Add docs
266        if let Some(Message::User { content }) = completion_request.normalized_documents() {
267            let text = content
268                .into_iter()
269                .filter_map(|doc| match doc {
270                    UserContent::Document(doc) => Some(doc.data),
271                    UserContent::Text(text) => Some(text.text),
272
273                    // This should always be `Document`
274                    _ => None,
275                })
276                .collect::<Vec<_>>()
277                .join("\n");
278
279            messages.push(serde_json::json!({
280                "role": "user",
281                "content": text
282            }));
283        }
284
285        // Add chat history
286        for msg in completion_request.chat_history {
287            let (role, content) = match msg {
288                Message::User { content } => {
289                    let text = content
290                        .iter()
291                        .map(|c| match c {
292                            UserContent::Text(text) => &text.text,
293                            _ => "",
294                        })
295                        .collect::<Vec<_>>()
296                        .join("\n");
297                    ("user", text)
298                }
299                Message::Assistant { content, .. } => {
300                    let text = content
301                        .iter()
302                        .map(|c| match c {
303                            AssistantContent::Text(text) => &text.text,
304                            _ => "",
305                        })
306                        .collect::<Vec<_>>()
307                        .join("\n");
308                    ("assistant", text)
309                }
310            };
311            messages.push(serde_json::json!({
312                "role": role,
313                "content": content
314            }));
315        }
316
317        let request = serde_json::json!({
318            "model": self.model,
319            "messages": messages,
320            "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
321            "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
322            "stream": false
323        });
324
325        Ok(request)
326    }
327}
328
329impl completion::CompletionModel for CompletionModel {
330    type Response = CompletionResponse;
331    type StreamingResponse = openai::StreamingCompletionResponse;
332
333    #[cfg_attr(feature = "worker", worker::send)]
334    async fn completion(
335        &self,
336        completion_request: CompletionRequest,
337    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
338        if !completion_request.tools.is_empty() {
339            tracing::warn!(target: "rig",
340                "Tool calls are not supported by the Mira provider. {} tools will be ignored.",
341                completion_request.tools.len()
342            );
343        }
344
345        let mira_request = self.create_completion_request(completion_request)?;
346
347        let response = self
348            .client
349            .http_client
350            .post(format!("{}/v1/chat/completions", self.client.base_url))
351            .bearer_auth(&self.client.api_key)
352            .headers(self.client.headers.clone())
353            .json(&mira_request)
354            .send()
355            .await
356            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
357
358        if !response.status().is_success() {
359            let status = response.status().as_u16();
360            let error_text = response.text().await.unwrap_or_default();
361            return Err(CompletionError::ProviderError(format!(
362                "API error: {status} - {error_text}"
363            )));
364        }
365
366        let response: CompletionResponse = response
367            .json()
368            .await
369            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
370
371        response.try_into()
372    }
373
374    #[cfg_attr(feature = "worker", worker::send)]
375    async fn stream(
376        &self,
377        completion_request: CompletionRequest,
378    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
379        let mut request = self.create_completion_request(completion_request)?;
380
381        request = merge(request, json!({"stream": true}));
382
383        let builder = self
384            .client
385            .http_client
386            .post(format!("{}/v1/chat/completions", self.client.base_url))
387            .headers(self.client.headers.clone())
388            .json(&request);
389
390        send_compatible_streaming_request(builder).await
391    }
392}
393
394impl From<ApiErrorResponse> for CompletionError {
395    fn from(err: ApiErrorResponse) -> Self {
396        CompletionError::ProviderError(err.message)
397    }
398}
399
400impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
401    type Error = CompletionError;
402
403    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
404        let (content, usage) = match &response {
405            CompletionResponse::Structured { choices, usage, .. } => {
406                let choice = choices.first().ok_or_else(|| {
407                    CompletionError::ResponseError("Response contained no choices".to_owned())
408                })?;
409
410                let usage = usage
411                    .as_ref()
412                    .map(|usage| completion::Usage {
413                        input_tokens: usage.prompt_tokens as u64,
414                        output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
415                        total_tokens: usage.total_tokens as u64,
416                    })
417                    .unwrap_or_default();
418
419                // Convert RawMessage to message::Message
420                let message = message::Message::try_from(choice.message.clone())?;
421
422                let content = match message {
423                    Message::Assistant { content, .. } => {
424                        if content.is_empty() {
425                            return Err(CompletionError::ResponseError(
426                                "Response contained empty content".to_owned(),
427                            ));
428                        }
429
430                        // Log warning for unsupported content types
431                        for c in content.iter() {
432                            if !matches!(c, AssistantContent::Text(_)) {
433                                tracing::warn!(target: "rig",
434                                    "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
435                                );
436                            }
437                        }
438
439                        content.iter().map(|c| {
440                            match c {
441                                AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
442                                other => Err(CompletionError::ResponseError(
443                                    format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
444                                ))
445                            }
446                        }).collect::<Result<Vec<_>, _>>()?
447                    }
448                    Message::User { .. } => {
449                        tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
450                        return Err(CompletionError::ResponseError(
451                            "Received user message in response where assistant message was expected".to_owned()
452                        ));
453                    }
454                };
455
456                (content, usage)
457            }
458            CompletionResponse::Simple(text) => (
459                vec![completion::AssistantContent::text(text)],
460                completion::Usage::new(),
461            ),
462        };
463
464        let choice = OneOrMany::many(content).map_err(|_| {
465            CompletionError::ResponseError(
466                "Response contained no message or tool call (empty)".to_owned(),
467            )
468        })?;
469
470        Ok(completion::CompletionResponse {
471            choice,
472            usage,
473            raw_response: response,
474        })
475    }
476}
477
478#[derive(Clone, Debug, Deserialize)]
479pub struct Usage {
480    pub prompt_tokens: usize,
481    pub total_tokens: usize,
482}
483
484impl std::fmt::Display for Usage {
485    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
486        write!(
487            f,
488            "Prompt tokens: {} Total tokens: {}",
489            self.prompt_tokens, self.total_tokens
490        )
491    }
492}
493
494impl From<Message> for serde_json::Value {
495    fn from(msg: Message) -> Self {
496        match msg {
497            Message::User { content } => {
498                let text = content
499                    .iter()
500                    .map(|c| match c {
501                        UserContent::Text(text) => &text.text,
502                        _ => "",
503                    })
504                    .collect::<Vec<_>>()
505                    .join("\n");
506                serde_json::json!({
507                    "role": "user",
508                    "content": text
509                })
510            }
511            Message::Assistant { content, .. } => {
512                let text = content
513                    .iter()
514                    .map(|c| match c {
515                        AssistantContent::Text(text) => &text.text,
516                        _ => "",
517                    })
518                    .collect::<Vec<_>>()
519                    .join("\n");
520                serde_json::json!({
521                    "role": "assistant",
522                    "content": text
523                })
524            }
525        }
526    }
527}
528
529impl TryFrom<serde_json::Value> for Message {
530    type Error = CompletionError;
531
532    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
533        let role = value["role"].as_str().ok_or_else(|| {
534            CompletionError::ResponseError("Message missing role field".to_owned())
535        })?;
536
537        // Handle both string and array content formats
538        let content = match value.get("content") {
539            Some(content) => match content {
540                serde_json::Value::String(s) => s.clone(),
541                serde_json::Value::Array(arr) => arr
542                    .iter()
543                    .filter_map(|c| {
544                        c.get("text")
545                            .and_then(|t| t.as_str())
546                            .map(|text| text.to_string())
547                    })
548                    .collect::<Vec<_>>()
549                    .join("\n"),
550                _ => {
551                    return Err(CompletionError::ResponseError(
552                        "Message content must be string or array".to_owned(),
553                    ));
554                }
555            },
556            None => {
557                return Err(CompletionError::ResponseError(
558                    "Message missing content field".to_owned(),
559                ));
560            }
561        };
562
563        match role {
564            "user" => Ok(Message::User {
565                content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
566            }),
567            "assistant" => Ok(Message::Assistant {
568                id: None,
569                content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
570            }),
571            _ => Err(CompletionError::ResponseError(format!(
572                "Unsupported message role: {role}"
573            ))),
574        }
575    }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// Assistant text shared by the deserialization fixtures.
    const GREETING: &str = "Hello there, how may I assist you today?";

    /// Returns the first content item of an assistant message; panics for any other message.
    fn first_assistant_content(msg: Message) -> AssistantContent {
        let Message::Assistant { content, .. } = msg else {
            panic!("Expected assistant message");
        };
        content.first()
    }

    /// Returns the first content item of a user message; panics for any other message.
    fn first_user_content(msg: Message) -> UserContent {
        let Message::User { content } = msg else {
            panic!("Expected user message");
        };
        content.first()
    }

    #[test]
    fn test_deserialize_message() {
        // String content format
        let assistant_from_string = Message::try_from(json!({
            "role": "assistant",
            "content": GREETING
        }))
        .unwrap();

        let user_from_string = Message::try_from(json!({
            "role": "user",
            "content": "What can you help me with?"
        }))
        .unwrap();

        // Array content format
        let assistant_from_array = Message::try_from(json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": GREETING
            }]
        }))
        .unwrap();

        let expected_greeting = AssistantContent::Text(message::Text {
            text: GREETING.to_string(),
        });

        assert_eq!(
            first_assistant_content(assistant_from_string),
            expected_greeting
        );

        assert_eq!(
            first_user_content(user_from_string),
            UserContent::Text(message::Text {
                text: "What can you help me with?".to_string()
            })
        );

        assert_eq!(
            first_assistant_content(assistant_from_array),
            expected_greeting
        );
    }

    #[test]
    fn test_message_conversion() {
        // A text-only user message must survive the trip to Mira's JSON format and back.
        let original = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let as_json = serde_json::Value::from(original.clone());
        let round_tripped = Message::try_from(as_json).unwrap();

        assert_eq!(original, round_tripped);
    }

    #[test]
    fn test_completion_response_conversion() {
        let only_choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };

        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![only_choice],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let converted =
            completion::CompletionResponse::<CompletionResponse>::try_from(mira_response)
                .unwrap();

        assert_eq!(
            converted.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
}