rig/providers/
mira.rs

1//! Mira API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::mira;
6//!
7//! let client = mira::Client::new("YOUR_API_KEY");
8//!
9//! ```
10use crate::client::{
11    ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError,
12};
13use crate::json_utils::merge;
14use crate::providers::openai;
15use crate::providers::openai::send_compatible_streaming_request;
16use crate::streaming::StreamingCompletionResponse;
17use crate::{
18    OneOrMany,
19    completion::{self, CompletionError, CompletionRequest},
20    impl_conversion_traits,
21    message::{self, AssistantContent, Message, UserContent},
22};
23use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
24use serde::{Deserialize, Serialize};
25use serde_json::{Value, json};
26use std::string::FromUtf8Error;
27use thiserror::Error;
28use tracing;
29
/// Error type returned by [`Client::list_models`].
#[derive(Debug, Error)]
pub enum MiraError {
    /// The API key is invalid.
    ///
    /// NOTE(review): no code in this file constructs this variant.
    #[error("Invalid API key")]
    InvalidApiKey,
    /// The API answered with a non-success HTTP status; the payload is that status code.
    #[error("API error: {0}")]
    ApiError(u16),
    /// The HTTP request failed inside `reqwest` (converted via `From<reqwest::Error>`).
    #[error("Request error: {0}")]
    RequestError(#[from] reqwest::Error),
    /// Bytes could not be decoded as UTF-8 (converted via `From<FromUtf8Error>`).
    #[error("UTF-8 error: {0}")]
    Utf8Error(#[from] FromUtf8Error),
    /// A response body could not be parsed as the expected JSON shape.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
}
43
/// Error body carrying a single `message` string.
///
/// Converted into `CompletionError::ProviderError` by the
/// `From<ApiErrorResponse>` impl in this file.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text; becomes the `ProviderError` payload on conversion.
    message: String,
}
48
/// A chat message as it appears on the wire: a role plus plain-text content.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct RawMessage {
    /// Role of the author; the conversion in this file accepts `"user"` and `"assistant"`.
    pub role: String,
    /// Text content of the message.
    pub content: String,
}
54
/// Base URL that [`ClientBuilder::new`] uses unless `base_url` overrides it.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
56
57impl TryFrom<RawMessage> for message::Message {
58    type Error = CompletionError;
59
60    fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
61        match raw.role.as_str() {
62            "user" => Ok(message::Message::User {
63                content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
64            }),
65            "assistant" => Ok(message::Message::Assistant {
66                id: None,
67                content: OneOrMany::one(AssistantContent::Text(message::Text {
68                    text: raw.content,
69                })),
70            }),
71            _ => Err(CompletionError::ResponseError(format!(
72                "Unsupported message role: {}",
73                raw.role
74            ))),
75        }
76    }
77}
78
/// Response body of a chat-completion call.
///
/// The enum is `#[serde(untagged)]`: serde tries `Structured` first and falls
/// back to `Simple` when the body is a bare JSON string.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum CompletionResponse {
    /// Object-shaped response with choices and optional usage.
    Structured {
        /// Response identifier.
        id: String,
        /// Object type string as sent by the server.
        object: String,
        /// Creation timestamp as sent by the server.
        /// NOTE(review): presumably Unix seconds; confirm against the API docs.
        created: u64,
        /// Name of the model that produced the response.
        model: String,
        /// Returned choices; only the first one is used when converting.
        choices: Vec<ChatChoice>,
        /// Token usage, omitted from serialization when absent.
        #[serde(skip_serializing_if = "Option::is_none")]
        usage: Option<Usage>,
    },
    /// Response consisting of nothing but the generated text.
    Simple(String),
}
93
/// One entry of the `choices` array in a structured completion response.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatChoice {
    /// The generated message.
    pub message: RawMessage,
    /// Finish reason string; `None` when the field is missing.
    #[serde(default)]
    pub finish_reason: Option<String>,
    /// Position of the choice; `None` when the field is missing.
    #[serde(default)]
    pub index: Option<usize>,
}
102
/// Body of the `/v1/models` response as parsed by `Client::list_models`.
#[derive(Debug, Deserialize, Serialize)]
struct ModelsResponse {
    /// The listed models.
    data: Vec<ModelInfo>,
}
107
/// One model entry from the model listing; only its identifier is read.
#[derive(Debug, Deserialize, Serialize)]
struct ModelInfo {
    /// Model identifier returned to callers of `Client::list_models`.
    id: String,
}
112
/// Builder for [`Client`], borrowing the API key and base URL until `build` copies them.
pub struct ClientBuilder<'a> {
    /// API key sent as the bearer token.
    api_key: &'a str,
    /// Base URL of the API; defaults to `MIRA_API_BASE_URL`.
    base_url: &'a str,
    /// Caller-supplied HTTP client; a default one is created in `build` when `None`.
    http_client: Option<reqwest::Client>,
}
118
119impl<'a> ClientBuilder<'a> {
120    pub fn new(api_key: &'a str) -> Self {
121        Self {
122            api_key,
123            base_url: MIRA_API_BASE_URL,
124            http_client: None,
125        }
126    }
127
128    pub fn base_url(mut self, base_url: &'a str) -> Self {
129        self.base_url = base_url;
130        self
131    }
132
133    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
134        self.http_client = Some(client);
135        self
136    }
137
138    pub fn build(self) -> Result<Client, ClientBuilderError> {
139        let mut headers = HeaderMap::new();
140        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
141        headers.insert(
142            reqwest::header::ACCEPT,
143            HeaderValue::from_static("application/json"),
144        );
145        headers.insert(
146            reqwest::header::USER_AGENT,
147            HeaderValue::from_static("rig-client/1.0"),
148        );
149        let http_client = if let Some(http_client) = self.http_client {
150            http_client
151        } else {
152            reqwest::Client::builder().build()?
153        };
154
155        Ok(Client {
156            base_url: self.base_url.to_string(),
157            http_client,
158            api_key: self.api_key.to_string(),
159            headers,
160        })
161    }
162}
163
#[derive(Clone)]
/// Client for interacting with the Mira API
pub struct Client {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// Underlying HTTP client used to issue requests.
    http_client: reqwest::Client,
    /// API key sent as the bearer token on every request.
    api_key: String,
    /// Headers attached to every request built by `get` and `post`.
    headers: HeaderMap,
}
172
173impl std::fmt::Debug for Client {
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        f.debug_struct("Client")
176            .field("base_url", &self.base_url)
177            .field("http_client", &self.http_client)
178            .field("api_key", &"<REDACTED>")
179            .field("headers", &self.headers)
180            .finish()
181    }
182}
183
184impl Client {
185    /// Create a new Mira client builder.
186    ///
187    /// # Example
188    /// ```
189    /// use rig::providers::mira::{ClientBuilder, self};
190    ///
191    /// // Initialize the Mira client
192    /// let mira = Client::builder("your-mira-api-key")
193    ///    .build()
194    /// ```
195    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
196        ClientBuilder::new(api_key)
197    }
198
199    /// Create a new Mira client. For more control, use the `builder` method.
200    ///
201    /// # Panics
202    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
203    pub fn new(api_key: &str) -> Self {
204        Self::builder(api_key)
205            .build()
206            .expect("Mira client should build")
207    }
208
209    /// List available models
210    pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
211        let response = self.get("/v1/models").send().await?;
212
213        let status = response.status();
214
215        if !status.is_success() {
216            // Log the error text but don't store it in an unused variable
217            let _error_text = response.text().await.unwrap_or_default();
218            tracing::error!("Error response: {}", _error_text);
219            return Err(MiraError::ApiError(status.as_u16()));
220        }
221
222        let response_text = response.text().await?;
223
224        let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
225            tracing::error!("Failed to parse response: {}", e);
226            MiraError::JsonError(e)
227        })?;
228
229        Ok(models.data.into_iter().map(|model| model.id).collect())
230    }
231
232    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
233        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
234        self.http_client
235            .post(url)
236            .bearer_auth(&self.api_key)
237            .headers(self.headers.clone())
238    }
239
240    pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
241        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
242        self.http_client
243            .get(url)
244            .bearer_auth(&self.api_key)
245            .headers(self.headers.clone())
246    }
247}
248
249impl ProviderClient for Client {
250    /// Create a new Mira client from the `MIRA_API_KEY` environment variable.
251    /// Panics if the environment variable is not set.
252    fn from_env() -> Self {
253        let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
254        Self::new(&api_key)
255    }
256
257    fn from_val(input: crate::client::ProviderValue) -> Self {
258        let crate::client::ProviderValue::Simple(api_key) = input else {
259            panic!("Incorrect provider value type")
260        };
261        Self::new(&api_key)
262    }
263}
264
265impl CompletionClient for Client {
266    type CompletionModel = CompletionModel;
267    /// Create a completion model with the given name.
268    fn completion_model(&self, model: &str) -> CompletionModel {
269        CompletionModel::new(self.to_owned(), model)
270    }
271}
272
273impl VerifyClient for Client {
274    #[cfg_attr(feature = "worker", worker::send)]
275    async fn verify(&self) -> Result<(), VerifyError> {
276        let response = self.get("/user-credits").send().await?;
277        match response.status() {
278            reqwest::StatusCode::OK => Ok(()),
279            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
280            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
281                Err(VerifyError::ProviderError(response.text().await?))
282            }
283            _ => {
284                response.error_for_status()?;
285                Ok(())
286            }
287        }
288    }
289}
290
// This module only implements `CompletionClient` for `Client`; the capability
// traits listed here are handled by `impl_conversion_traits!` instead.
// NOTE(review): presumably the macro generates "not supported" conversions for
// each listed trait; confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
297
/// A Mira completion model: a [`Client`] paired with the name of the model to call.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send the completion requests.
    client: Client,
    /// Name of the model
    pub model: String,
}
304
305impl CompletionModel {
306    pub fn new(client: Client, model: &str) -> Self {
307        Self {
308            client,
309            model: model.to_string(),
310        }
311    }
312
313    fn create_completion_request(
314        &self,
315        completion_request: CompletionRequest,
316    ) -> Result<Value, CompletionError> {
317        let mut messages = Vec::new();
318
319        // Add preamble as user message if available
320        if let Some(preamble) = &completion_request.preamble {
321            messages.push(serde_json::json!({
322                "role": "user",
323                "content": preamble.to_string()
324            }));
325        }
326
327        // Add docs
328        if let Some(Message::User { content }) = completion_request.normalized_documents() {
329            let text = content
330                .into_iter()
331                .filter_map(|doc| match doc {
332                    UserContent::Document(doc) => Some(doc.data),
333                    UserContent::Text(text) => Some(text.text),
334
335                    // This should always be `Document`
336                    _ => None,
337                })
338                .collect::<Vec<_>>()
339                .join("\n");
340
341            messages.push(serde_json::json!({
342                "role": "user",
343                "content": text
344            }));
345        }
346
347        // Add chat history
348        for msg in completion_request.chat_history {
349            let (role, content) = match msg {
350                Message::User { content } => {
351                    let text = content
352                        .iter()
353                        .map(|c| match c {
354                            UserContent::Text(text) => &text.text,
355                            _ => "",
356                        })
357                        .collect::<Vec<_>>()
358                        .join("\n");
359                    ("user", text)
360                }
361                Message::Assistant { content, .. } => {
362                    let text = content
363                        .iter()
364                        .map(|c| match c {
365                            AssistantContent::Text(text) => &text.text,
366                            _ => "",
367                        })
368                        .collect::<Vec<_>>()
369                        .join("\n");
370                    ("assistant", text)
371                }
372            };
373            messages.push(serde_json::json!({
374                "role": role,
375                "content": content
376            }));
377        }
378
379        let request = serde_json::json!({
380            "model": self.model,
381            "messages": messages,
382            "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
383            "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
384            "stream": false
385        });
386
387        Ok(request)
388    }
389}
390
391impl completion::CompletionModel for CompletionModel {
392    type Response = CompletionResponse;
393    type StreamingResponse = openai::StreamingCompletionResponse;
394
395    #[cfg_attr(feature = "worker", worker::send)]
396    async fn completion(
397        &self,
398        completion_request: CompletionRequest,
399    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
400        if !completion_request.tools.is_empty() {
401            tracing::warn!(target: "rig",
402                "Tool calls are not supported by the Mira provider. {} tools will be ignored.",
403                completion_request.tools.len()
404            );
405        }
406
407        let mira_request = self.create_completion_request(completion_request)?;
408
409        let response = self
410            .client
411            .post("/v1/chat/completions")
412            .json(&mira_request)
413            .send()
414            .await
415            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
416
417        if !response.status().is_success() {
418            let status = response.status().as_u16();
419            let error_text = response.text().await.unwrap_or_default();
420            return Err(CompletionError::ProviderError(format!(
421                "API error: {status} - {error_text}"
422            )));
423        }
424
425        let response: CompletionResponse = response
426            .json()
427            .await
428            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
429
430        response.try_into()
431    }
432
433    #[cfg_attr(feature = "worker", worker::send)]
434    async fn stream(
435        &self,
436        completion_request: CompletionRequest,
437    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
438        let mut request = self.create_completion_request(completion_request)?;
439
440        request = merge(request, json!({"stream": true}));
441
442        let builder = self.client.post("/v1/chat/completions").json(&request);
443
444        send_compatible_streaming_request(builder).await
445    }
446}
447
448impl From<ApiErrorResponse> for CompletionError {
449    fn from(err: ApiErrorResponse) -> Self {
450        CompletionError::ProviderError(err.message)
451    }
452}
453
454impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
455    type Error = CompletionError;
456
457    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
458        let (content, usage) = match &response {
459            CompletionResponse::Structured { choices, usage, .. } => {
460                let choice = choices.first().ok_or_else(|| {
461                    CompletionError::ResponseError("Response contained no choices".to_owned())
462                })?;
463
464                let usage = usage
465                    .as_ref()
466                    .map(|usage| completion::Usage {
467                        input_tokens: usage.prompt_tokens as u64,
468                        output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
469                        total_tokens: usage.total_tokens as u64,
470                    })
471                    .unwrap_or_default();
472
473                // Convert RawMessage to message::Message
474                let message = message::Message::try_from(choice.message.clone())?;
475
476                let content = match message {
477                    Message::Assistant { content, .. } => {
478                        if content.is_empty() {
479                            return Err(CompletionError::ResponseError(
480                                "Response contained empty content".to_owned(),
481                            ));
482                        }
483
484                        // Log warning for unsupported content types
485                        for c in content.iter() {
486                            if !matches!(c, AssistantContent::Text(_)) {
487                                tracing::warn!(target: "rig",
488                                    "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
489                                );
490                            }
491                        }
492
493                        content.iter().map(|c| {
494                            match c {
495                                AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
496                                other => Err(CompletionError::ResponseError(
497                                    format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
498                                ))
499                            }
500                        }).collect::<Result<Vec<_>, _>>()?
501                    }
502                    Message::User { .. } => {
503                        tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
504                        return Err(CompletionError::ResponseError(
505                            "Received user message in response where assistant message was expected".to_owned()
506                        ));
507                    }
508                };
509
510                (content, usage)
511            }
512            CompletionResponse::Simple(text) => (
513                vec![completion::AssistantContent::text(text)],
514                completion::Usage::new(),
515            ),
516        };
517
518        let choice = OneOrMany::many(content).map_err(|_| {
519            CompletionError::ResponseError(
520                "Response contained no message or tool call (empty)".to_owned(),
521            )
522        })?;
523
524        Ok(completion::CompletionResponse {
525            choice,
526            usage,
527            raw_response: response,
528        })
529    }
530}
531
/// Token usage reported in a structured completion response.
///
/// Output tokens are not reported separately; the response conversion in this
/// file derives them as `total_tokens - prompt_tokens`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Total number of tokens reported for the request.
    pub total_tokens: usize,
}
537
538impl std::fmt::Display for Usage {
539    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
540        write!(
541            f,
542            "Prompt tokens: {} Total tokens: {}",
543            self.prompt_tokens, self.total_tokens
544        )
545    }
546}
547
548impl From<Message> for serde_json::Value {
549    fn from(msg: Message) -> Self {
550        match msg {
551            Message::User { content } => {
552                let text = content
553                    .iter()
554                    .map(|c| match c {
555                        UserContent::Text(text) => &text.text,
556                        _ => "",
557                    })
558                    .collect::<Vec<_>>()
559                    .join("\n");
560                serde_json::json!({
561                    "role": "user",
562                    "content": text
563                })
564            }
565            Message::Assistant { content, .. } => {
566                let text = content
567                    .iter()
568                    .map(|c| match c {
569                        AssistantContent::Text(text) => &text.text,
570                        _ => "",
571                    })
572                    .collect::<Vec<_>>()
573                    .join("\n");
574                serde_json::json!({
575                    "role": "assistant",
576                    "content": text
577                })
578            }
579        }
580    }
581}
582
583impl TryFrom<serde_json::Value> for Message {
584    type Error = CompletionError;
585
586    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
587        let role = value["role"].as_str().ok_or_else(|| {
588            CompletionError::ResponseError("Message missing role field".to_owned())
589        })?;
590
591        // Handle both string and array content formats
592        let content = match value.get("content") {
593            Some(content) => match content {
594                serde_json::Value::String(s) => s.clone(),
595                serde_json::Value::Array(arr) => arr
596                    .iter()
597                    .filter_map(|c| {
598                        c.get("text")
599                            .and_then(|t| t.as_str())
600                            .map(|text| text.to_string())
601                    })
602                    .collect::<Vec<_>>()
603                    .join("\n"),
604                _ => {
605                    return Err(CompletionError::ResponseError(
606                        "Message content must be string or array".to_owned(),
607                    ));
608                }
609            },
610            None => {
611                return Err(CompletionError::ResponseError(
612                    "Message missing content field".to_owned(),
613                ));
614            }
615        };
616
617        match role {
618            "user" => Ok(Message::User {
619                content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
620            }),
621            "assistant" => Ok(Message::Assistant {
622                id: None,
623                content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
624            }),
625            _ => Err(CompletionError::ResponseError(format!(
626                "Unsupported message role: {role}"
627            ))),
628        }
629    }
630}
631
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// Assistant greeting shared by the string-content and array-content cases.
    const GREETING: &str = "Hello there, how may I assist you today?";

    /// JSON messages with string or array content parse into the matching `Message`.
    #[test]
    fn test_deserialize_message() {
        // String content format
        let assistant_message = Message::try_from(json!({
            "role": "assistant",
            "content": GREETING
        }))
        .unwrap();
        let user_message = Message::try_from(json!({
            "role": "user",
            "content": "What can you help me with?"
        }))
        .unwrap();

        // Array content format
        let assistant_message_array = Message::try_from(json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": GREETING
            }]
        }))
        .unwrap();

        let expected_assistant = AssistantContent::Text(message::Text {
            text: GREETING.to_string(),
        });
        let expected_user = UserContent::Text(message::Text {
            text: "What can you help me with?".to_string(),
        });

        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(content.first(), expected_assistant);

        let Message::User { content } = user_message else {
            panic!("Expected user message");
        };
        assert_eq!(content.first(), expected_user);

        let Message::Assistant { content, .. } = assistant_message_array else {
            panic!("Expected assistant message");
        };
        assert_eq!(content.first(), expected_assistant);
    }

    /// A user message survives the round trip through Mira's JSON format.
    #[test]
    fn test_message_conversion() {
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        // To Mira's format and back again.
        let mira_value = serde_json::Value::from(original_message.clone());
        let converted_message = Message::try_from(mira_value).unwrap();

        assert_eq!(original_message, converted_message);
    }

    /// A structured response converts into a completion response holding its text.
    #[test]
    fn test_completion_response_conversion() {
        let only_choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };
        let reported_usage = Usage {
            prompt_tokens: 10,
            total_tokens: 20,
        };

        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![only_choice],
            usage: Some(reported_usage),
        };

        let completion_response =
            completion::CompletionResponse::<CompletionResponse>::try_from(mira_response).unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
}
748}