rig/providers/
mira.rs

1//! Mira API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::mira;
6//!
7//! let client = mira::Client::new("YOUR_API_KEY");
8//!
9//! ```
10use crate::client::{ClientBuilderError, CompletionClient, ProviderClient};
11use crate::json_utils::merge;
12use crate::providers::openai;
13use crate::providers::openai::send_compatible_streaming_request;
14use crate::streaming::StreamingCompletionResponse;
15use crate::{
16    OneOrMany,
17    completion::{self, CompletionError, CompletionRequest},
18    impl_conversion_traits,
19    message::{self, AssistantContent, Message, UserContent},
20};
21use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
22use serde::{Deserialize, Serialize};
23use serde_json::{Value, json};
24use std::string::FromUtf8Error;
25use thiserror::Error;
26use tracing;
27
/// Errors produced by the Mira-specific client calls in this module
/// (currently returned by `Client::list_models`).
#[derive(Debug, Error)]
pub enum MiraError {
    /// The API key is invalid.
    /// NOTE(review): this variant is not constructed anywhere in the visible file.
    #[error("Invalid API key")]
    InvalidApiKey,
    /// The API answered with a non-success HTTP status; the payload is that status code.
    #[error("API error: {0}")]
    ApiError(u16),
    /// The HTTP request failed, or the response body could not be read.
    #[error("Request error: {0}")]
    RequestError(#[from] reqwest::Error),
    /// A byte sequence was not valid UTF-8.
    #[error("UTF-8 error: {0}")]
    Utf8Error(#[from] FromUtf8Error),
    /// The response body could not be deserialized as the expected JSON.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
}
41
/// Error payload carrying a single `message` string.
///
/// It can be turned into a `CompletionError::ProviderError` through the `From`
/// impl defined later in this file.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Human-readable error description.
    message: String,
}
46
/// A chat message in wire format: a role name plus plain-text content.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct RawMessage {
    /// Role of the author; only `"user"` and `"assistant"` are accepted when
    /// converting into `message::Message`.
    pub role: String,
    /// Plain-text body of the message.
    pub content: String,
}
52
/// Base URL used by `ClientBuilder::new` unless overridden with `ClientBuilder::base_url`.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
54
55impl TryFrom<RawMessage> for message::Message {
56    type Error = CompletionError;
57
58    fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
59        match raw.role.as_str() {
60            "user" => Ok(message::Message::User {
61                content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
62            }),
63            "assistant" => Ok(message::Message::Assistant {
64                id: None,
65                content: OneOrMany::one(AssistantContent::Text(message::Text {
66                    text: raw.content,
67                })),
68            }),
69            _ => Err(CompletionError::ResponseError(format!(
70                "Unsupported message role: {}",
71                raw.role
72            ))),
73        }
74    }
75}
76
/// Response body of a chat completion call.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// first the structured object form, then a bare JSON string.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum CompletionResponse {
    /// Object-shaped response containing the generated choices.
    Structured {
        /// Identifier of the response.
        id: String,
        /// Object type tag, e.g. `"chat.completion"` as used in this file's tests.
        object: String,
        /// Creation timestamp as reported by the API.
        /// NOTE(review): presumably Unix seconds — confirm against the API.
        created: u64,
        /// Name of the model that produced the response.
        model: String,
        /// Generated choices; only the first one is used when converting to Rig's response type.
        choices: Vec<ChatChoice>,
        /// Token accounting, when the API supplies it; omitted from serialization when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        usage: Option<Usage>,
    },
    /// Response consisting of nothing but the generated text.
    Simple(String),
}
91
/// One entry of the `choices` array in a structured completion response.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatChoice {
    /// The generated message.
    pub message: RawMessage,
    /// Why generation stopped (e.g. `"stop"`); `None` when the field is absent.
    #[serde(default)]
    pub finish_reason: Option<String>,
    /// Position of this choice in the response; `None` when the field is absent.
    #[serde(default)]
    pub index: Option<usize>,
}
100
/// Body of the `/v1/models` response as parsed by `Client::list_models`.
#[derive(Debug, Deserialize, Serialize)]
struct ModelsResponse {
    /// The listed models.
    data: Vec<ModelInfo>,
}
105
/// A single model entry inside `ModelsResponse`; only the identifier is read.
#[derive(Debug, Deserialize, Serialize)]
struct ModelInfo {
    /// Model identifier returned to callers of `Client::list_models`.
    id: String,
}
110
/// Builder for [`Client`], borrowing the API key and base URL until `build` is called.
pub struct ClientBuilder<'a> {
    /// API key sent as a bearer token by the built client.
    api_key: &'a str,
    /// Base URL of the API; defaults to `MIRA_API_BASE_URL`.
    base_url: &'a str,
    /// Optional caller-supplied HTTP client; a default one is created when `None`.
    http_client: Option<reqwest::Client>,
}
116
117impl<'a> ClientBuilder<'a> {
118    pub fn new(api_key: &'a str) -> Self {
119        Self {
120            api_key,
121            base_url: MIRA_API_BASE_URL,
122            http_client: None,
123        }
124    }
125
126    pub fn base_url(mut self, base_url: &'a str) -> Self {
127        self.base_url = base_url;
128        self
129    }
130
131    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
132        self.http_client = Some(client);
133        self
134    }
135
136    pub fn build(self) -> Result<Client, ClientBuilderError> {
137        let mut headers = HeaderMap::new();
138        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
139        headers.insert(
140            reqwest::header::ACCEPT,
141            HeaderValue::from_static("application/json"),
142        );
143        headers.insert(
144            reqwest::header::USER_AGENT,
145            HeaderValue::from_static("rig-client/1.0"),
146        );
147        let http_client = if let Some(http_client) = self.http_client {
148            http_client
149        } else {
150            reqwest::Client::builder().build()?
151        };
152
153        Ok(Client {
154            base_url: self.base_url.to_string(),
155            http_client,
156            api_key: self.api_key.to_string(),
157            headers,
158        })
159    }
160}
161
#[derive(Clone)]
/// Client for interacting with the Mira API.
///
/// Holds the base URL, the underlying HTTP client, the API key (sent as a
/// bearer token) and the static headers attached to requests.
pub struct Client {
    /// Base URL that endpoint paths such as `/v1/models` are appended to.
    base_url: String,
    /// HTTP client used to send requests.
    http_client: reqwest::Client,
    /// API key; redacted in the `Debug` output.
    api_key: String,
    /// Headers (content type, accept, user agent) set up by `ClientBuilder::build`.
    headers: HeaderMap,
}
170
171impl std::fmt::Debug for Client {
172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        f.debug_struct("Client")
174            .field("base_url", &self.base_url)
175            .field("http_client", &self.http_client)
176            .field("api_key", &"<REDACTED>")
177            .field("headers", &self.headers)
178            .finish()
179    }
180}
181
182impl Client {
183    /// Create a new Mira client builder.
184    ///
185    /// # Example
186    /// ```
187    /// use rig::providers::mira::{ClientBuilder, self};
188    ///
189    /// // Initialize the Mira client
190    /// let mira = Client::builder("your-mira-api-key")
191    ///    .build()
192    /// ```
193    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
194        ClientBuilder::new(api_key)
195    }
196
197    /// Create a new Mira client. For more control, use the `builder` method.
198    ///
199    /// # Panics
200    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
201    pub fn new(api_key: &str) -> Self {
202        Self::builder(api_key)
203            .build()
204            .expect("Mira client should build")
205    }
206
207    /// List available models
208    pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
209        let url = format!("{}/v1/models", self.base_url);
210
211        let response = self
212            .http_client
213            .get(&url)
214            .bearer_auth(&self.api_key)
215            .headers(self.headers.clone())
216            .send()
217            .await?;
218
219        let status = response.status();
220
221        if !status.is_success() {
222            // Log the error text but don't store it in an unused variable
223            let _error_text = response.text().await.unwrap_or_default();
224            tracing::error!("Error response: {}", _error_text);
225            return Err(MiraError::ApiError(status.as_u16()));
226        }
227
228        let response_text = response.text().await?;
229
230        let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
231            tracing::error!("Failed to parse response: {}", e);
232            MiraError::JsonError(e)
233        })?;
234
235        Ok(models.data.into_iter().map(|model| model.id).collect())
236    }
237}
238
239impl ProviderClient for Client {
240    /// Create a new Mira client from the `MIRA_API_KEY` environment variable.
241    /// Panics if the environment variable is not set.
242    fn from_env() -> Self {
243        let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
244        Self::new(&api_key)
245    }
246
247    fn from_val(input: crate::client::ProviderValue) -> Self {
248        let crate::client::ProviderValue::Simple(api_key) = input else {
249            panic!("Incorrect provider value type")
250        };
251        Self::new(&api_key)
252    }
253}
254
255impl CompletionClient for Client {
256    type CompletionModel = CompletionModel;
257    /// Create a completion model with the given name.
258    fn completion_model(&self, model: &str) -> CompletionModel {
259        CompletionModel::new(self.to_owned(), model)
260    }
261}
262
// Implements the listed conversion traits for `Client` through the crate's
// `impl_conversion_traits!` macro.
// NOTE(review): presumably these impls mark embeddings, transcription, image and
// audio generation as unsupported by this provider — confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
269
/// A completion model bound to a Mira [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send the completion requests.
    client: Client,
    /// Name of the model
    pub model: String,
}
276
277impl CompletionModel {
278    pub fn new(client: Client, model: &str) -> Self {
279        Self {
280            client,
281            model: model.to_string(),
282        }
283    }
284
285    fn create_completion_request(
286        &self,
287        completion_request: CompletionRequest,
288    ) -> Result<Value, CompletionError> {
289        let mut messages = Vec::new();
290
291        // Add preamble as user message if available
292        if let Some(preamble) = &completion_request.preamble {
293            messages.push(serde_json::json!({
294                "role": "user",
295                "content": preamble.to_string()
296            }));
297        }
298
299        // Add docs
300        if let Some(Message::User { content }) = completion_request.normalized_documents() {
301            let text = content
302                .into_iter()
303                .filter_map(|doc| match doc {
304                    UserContent::Document(doc) => Some(doc.data),
305                    UserContent::Text(text) => Some(text.text),
306
307                    // This should always be `Document`
308                    _ => None,
309                })
310                .collect::<Vec<_>>()
311                .join("\n");
312
313            messages.push(serde_json::json!({
314                "role": "user",
315                "content": text
316            }));
317        }
318
319        // Add chat history
320        for msg in completion_request.chat_history {
321            let (role, content) = match msg {
322                Message::User { content } => {
323                    let text = content
324                        .iter()
325                        .map(|c| match c {
326                            UserContent::Text(text) => &text.text,
327                            _ => "",
328                        })
329                        .collect::<Vec<_>>()
330                        .join("\n");
331                    ("user", text)
332                }
333                Message::Assistant { content, .. } => {
334                    let text = content
335                        .iter()
336                        .map(|c| match c {
337                            AssistantContent::Text(text) => &text.text,
338                            _ => "",
339                        })
340                        .collect::<Vec<_>>()
341                        .join("\n");
342                    ("assistant", text)
343                }
344            };
345            messages.push(serde_json::json!({
346                "role": role,
347                "content": content
348            }));
349        }
350
351        let request = serde_json::json!({
352            "model": self.model,
353            "messages": messages,
354            "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
355            "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
356            "stream": false
357        });
358
359        Ok(request)
360    }
361}
362
363impl completion::CompletionModel for CompletionModel {
364    type Response = CompletionResponse;
365    type StreamingResponse = openai::StreamingCompletionResponse;
366
367    #[cfg_attr(feature = "worker", worker::send)]
368    async fn completion(
369        &self,
370        completion_request: CompletionRequest,
371    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
372        if !completion_request.tools.is_empty() {
373            tracing::warn!(target: "rig",
374                "Tool calls are not supported by the Mira provider. {} tools will be ignored.",
375                completion_request.tools.len()
376            );
377        }
378
379        let mira_request = self.create_completion_request(completion_request)?;
380
381        let response = self
382            .client
383            .http_client
384            .post(format!("{}/v1/chat/completions", self.client.base_url))
385            .bearer_auth(&self.client.api_key)
386            .headers(self.client.headers.clone())
387            .json(&mira_request)
388            .send()
389            .await
390            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
391
392        if !response.status().is_success() {
393            let status = response.status().as_u16();
394            let error_text = response.text().await.unwrap_or_default();
395            return Err(CompletionError::ProviderError(format!(
396                "API error: {status} - {error_text}"
397            )));
398        }
399
400        let response: CompletionResponse = response
401            .json()
402            .await
403            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
404
405        response.try_into()
406    }
407
408    #[cfg_attr(feature = "worker", worker::send)]
409    async fn stream(
410        &self,
411        completion_request: CompletionRequest,
412    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
413        let mut request = self.create_completion_request(completion_request)?;
414
415        request = merge(request, json!({"stream": true}));
416
417        let builder = self
418            .client
419            .http_client
420            .post(format!("{}/v1/chat/completions", self.client.base_url))
421            .headers(self.client.headers.clone())
422            .json(&request);
423
424        send_compatible_streaming_request(builder).await
425    }
426}
427
428impl From<ApiErrorResponse> for CompletionError {
429    fn from(err: ApiErrorResponse) -> Self {
430        CompletionError::ProviderError(err.message)
431    }
432}
433
434impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
435    type Error = CompletionError;
436
437    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
438        let (content, usage) = match &response {
439            CompletionResponse::Structured { choices, usage, .. } => {
440                let choice = choices.first().ok_or_else(|| {
441                    CompletionError::ResponseError("Response contained no choices".to_owned())
442                })?;
443
444                let usage = usage
445                    .as_ref()
446                    .map(|usage| completion::Usage {
447                        input_tokens: usage.prompt_tokens as u64,
448                        output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
449                        total_tokens: usage.total_tokens as u64,
450                    })
451                    .unwrap_or_default();
452
453                // Convert RawMessage to message::Message
454                let message = message::Message::try_from(choice.message.clone())?;
455
456                let content = match message {
457                    Message::Assistant { content, .. } => {
458                        if content.is_empty() {
459                            return Err(CompletionError::ResponseError(
460                                "Response contained empty content".to_owned(),
461                            ));
462                        }
463
464                        // Log warning for unsupported content types
465                        for c in content.iter() {
466                            if !matches!(c, AssistantContent::Text(_)) {
467                                tracing::warn!(target: "rig",
468                                    "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
469                                );
470                            }
471                        }
472
473                        content.iter().map(|c| {
474                            match c {
475                                AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
476                                other => Err(CompletionError::ResponseError(
477                                    format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
478                                ))
479                            }
480                        }).collect::<Result<Vec<_>, _>>()?
481                    }
482                    Message::User { .. } => {
483                        tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
484                        return Err(CompletionError::ResponseError(
485                            "Received user message in response where assistant message was expected".to_owned()
486                        ));
487                    }
488                };
489
490                (content, usage)
491            }
492            CompletionResponse::Simple(text) => (
493                vec![completion::AssistantContent::text(text)],
494                completion::Usage::new(),
495            ),
496        };
497
498        let choice = OneOrMany::many(content).map_err(|_| {
499            CompletionError::ResponseError(
500                "Response contained no message or tool call (empty)".to_owned(),
501            )
502        })?;
503
504        Ok(completion::CompletionResponse {
505            choice,
506            usage,
507            raw_response: response,
508        })
509    }
510}
511
/// Token accounting attached to a structured completion response.
///
/// Only prompt and total counts are carried; the `TryFrom<CompletionResponse>`
/// conversion in this file derives the output-token count from these two fields.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Total number of tokens reported for the request.
    pub total_tokens: usize,
}
517
518impl std::fmt::Display for Usage {
519    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
520        write!(
521            f,
522            "Prompt tokens: {} Total tokens: {}",
523            self.prompt_tokens, self.total_tokens
524        )
525    }
526}
527
528impl From<Message> for serde_json::Value {
529    fn from(msg: Message) -> Self {
530        match msg {
531            Message::User { content } => {
532                let text = content
533                    .iter()
534                    .map(|c| match c {
535                        UserContent::Text(text) => &text.text,
536                        _ => "",
537                    })
538                    .collect::<Vec<_>>()
539                    .join("\n");
540                serde_json::json!({
541                    "role": "user",
542                    "content": text
543                })
544            }
545            Message::Assistant { content, .. } => {
546                let text = content
547                    .iter()
548                    .map(|c| match c {
549                        AssistantContent::Text(text) => &text.text,
550                        _ => "",
551                    })
552                    .collect::<Vec<_>>()
553                    .join("\n");
554                serde_json::json!({
555                    "role": "assistant",
556                    "content": text
557                })
558            }
559        }
560    }
561}
562
563impl TryFrom<serde_json::Value> for Message {
564    type Error = CompletionError;
565
566    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
567        let role = value["role"].as_str().ok_or_else(|| {
568            CompletionError::ResponseError("Message missing role field".to_owned())
569        })?;
570
571        // Handle both string and array content formats
572        let content = match value.get("content") {
573            Some(content) => match content {
574                serde_json::Value::String(s) => s.clone(),
575                serde_json::Value::Array(arr) => arr
576                    .iter()
577                    .filter_map(|c| {
578                        c.get("text")
579                            .and_then(|t| t.as_str())
580                            .map(|text| text.to_string())
581                    })
582                    .collect::<Vec<_>>()
583                    .join("\n"),
584                _ => {
585                    return Err(CompletionError::ResponseError(
586                        "Message content must be string or array".to_owned(),
587                    ));
588                }
589            },
590            None => {
591                return Err(CompletionError::ResponseError(
592                    "Message missing content field".to_owned(),
593                ));
594            }
595        };
596
597        match role {
598            "user" => Ok(Message::User {
599                content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
600            }),
601            "assistant" => Ok(Message::Assistant {
602                id: None,
603                content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
604            }),
605            _ => Err(CompletionError::ResponseError(format!(
606                "Unsupported message role: {role}"
607            ))),
608        }
609    }
610}
611
// Unit tests for the JSON <-> `Message` conversions and the provider response conversion.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// `Message::try_from(Value)` accepts both string content and array-of-text content.
    #[test]
    fn test_deserialize_message() {
        // Test string content format
        let assistant_message_json = json!({
            "role": "assistant",
            "content": "Hello there, how may I assist you today?"
        });

        let user_message_json = json!({
            "role": "user",
            "content": "What can you help me with?"
        });

        // Test array content format
        let assistant_message_array_json = json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Hello there, how may I assist you today?"
            }]
        });

        let assistant_message = Message::try_from(assistant_message_json).unwrap();
        let user_message = Message::try_from(user_message_json).unwrap();
        let assistant_message_array = Message::try_from(assistant_message_array_json).unwrap();

        // Test string content format
        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text(message::Text {
                        text: "What can you help me with?".to_string()
                    })
                );
            }
            _ => panic!("Expected user message"),
        }

        // Test array content format
        match assistant_message_array {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }
    }

    /// A text-only user message survives a round trip through the JSON representation.
    #[test]
    fn test_message_conversion() {
        // Test converting from our Message type to Mira's format and back
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        // Convert to Mira format
        let mira_value: serde_json::Value = original_message.clone().into();

        // Convert back to our Message type
        let converted_message: Message = mira_value.try_into().unwrap();

        assert_eq!(original_message, converted_message);
    }

    /// A structured response with one assistant choice converts to a single text content.
    #[test]
    fn test_completion_response_conversion() {
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![ChatChoice {
                message: RawMessage {
                    role: "assistant".to_string(),
                    content: "Test response".to_string(),
                },
                finish_reason: Some("stop".to_string()),
                index: Some(0),
            }],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let completion_response: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
}
726        );
727    }
728}