rig/providers/
mira.rs

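//! Mira API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::mira;
//!
//! // `Client::new` returns a `Result`, so handle or propagate the error before use.
//! let client = mira::Client::new("YOUR_API_KEY").expect("failed to create Mira client");
//! ```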
10use crate::client::{CompletionClient, ProviderClient};
11use crate::json_utils::merge;
12use crate::providers::openai;
13use crate::providers::openai::send_compatible_streaming_request;
14use crate::streaming::StreamingCompletionResponse;
15use crate::{
16    OneOrMany,
17    completion::{self, CompletionError, CompletionRequest},
18    impl_conversion_traits,
19    message::{self, AssistantContent, Message, UserContent},
20};
21use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
22use serde::Deserialize;
23use serde_json::{Value, json};
24use std::string::FromUtf8Error;
25use thiserror::Error;
26use tracing;
27
28#[derive(Debug, Error)]
29pub enum MiraError {
30    #[error("Invalid API key")]
31    InvalidApiKey,
32    #[error("API error: {0}")]
33    ApiError(u16),
34    #[error("Request error: {0}")]
35    RequestError(#[from] reqwest::Error),
36    #[error("UTF-8 error: {0}")]
37    Utf8Error(#[from] FromUtf8Error),
38    #[error("JSON error: {0}")]
39    JsonError(#[from] serde_json::Error),
40}
41
42#[derive(Debug, Deserialize)]
43struct ApiErrorResponse {
44    message: String,
45}
46
47#[derive(Debug, Deserialize, Clone)]
48pub struct RawMessage {
49    pub role: String,
50    pub content: String,
51}
52
/// Default root URL of the Mira API; endpoint paths such as `/v1/models` are appended to it.
const MIRA_API_BASE_URL: &str = concat!("https://", "api.mira.network");
54
55impl TryFrom<RawMessage> for message::Message {
56    type Error = CompletionError;
57
58    fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
59        match raw.role.as_str() {
60            "user" => Ok(message::Message::User {
61                content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
62            }),
63            "assistant" => Ok(message::Message::Assistant {
64                id: None,
65                content: OneOrMany::one(AssistantContent::Text(message::Text {
66                    text: raw.content,
67                })),
68            }),
69            _ => Err(CompletionError::ResponseError(format!(
70                "Unsupported message role: {}",
71                raw.role
72            ))),
73        }
74    }
75}
76
77#[derive(Debug, Deserialize)]
78#[serde(untagged)]
79pub enum CompletionResponse {
80    Structured {
81        id: String,
82        object: String,
83        created: u64,
84        model: String,
85        choices: Vec<ChatChoice>,
86        #[serde(skip_serializing_if = "Option::is_none")]
87        usage: Option<Usage>,
88    },
89    Simple(String),
90}
91
92#[derive(Debug, Deserialize)]
93pub struct ChatChoice {
94    pub message: RawMessage,
95    #[serde(default)]
96    pub finish_reason: Option<String>,
97    #[serde(default)]
98    pub index: Option<usize>,
99}
100
101#[derive(Debug, Deserialize)]
102struct ModelsResponse {
103    data: Vec<ModelInfo>,
104}
105
106#[derive(Debug, Deserialize)]
107struct ModelInfo {
108    id: String,
109}
110
111#[derive(Clone)]
112/// Client for interacting with the Mira API
113pub struct Client {
114    base_url: String,
115    http_client: reqwest::Client,
116    api_key: String,
117    headers: HeaderMap,
118}
119
120impl std::fmt::Debug for Client {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.debug_struct("Client")
123            .field("base_url", &self.base_url)
124            .field("http_client", &self.http_client)
125            .field("api_key", &"<REDACTED>")
126            .field("headers", &self.headers)
127            .finish()
128    }
129}
130
131impl Client {
132    /// Create a new Mira client with the given API key
133    pub fn new(api_key: &str) -> Result<Self, MiraError> {
134        let mut headers = HeaderMap::new();
135        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
136        headers.insert(
137            reqwest::header::ACCEPT,
138            HeaderValue::from_static("application/json"),
139        );
140        headers.insert(
141            reqwest::header::USER_AGENT,
142            HeaderValue::from_static("rig-client/1.0"),
143        );
144
145        Ok(Self {
146            base_url: MIRA_API_BASE_URL.to_string(),
147            api_key: api_key.to_string(),
148            http_client: reqwest::Client::builder()
149                .build()
150                .expect("Failed to build HTTP client"),
151            headers,
152        })
153    }
154
155    /// Create a new Mira client with a custom base URL and API key
156    pub fn new_with_base_url(
157        api_key: &str,
158        base_url: impl Into<String>,
159    ) -> Result<Self, MiraError> {
160        let mut client = Self::new(api_key)?;
161        client.base_url = base_url.into();
162        Ok(client)
163    }
164
165    /// Use your own `reqwest::Client`.
166    /// The required headers will be automatically attached upon trying to make a request.
167    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
168        self.http_client = client;
169
170        self
171    }
172
173    /// List available models
174    pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
175        let url = format!("{}/v1/models", self.base_url);
176
177        let response = self
178            .http_client
179            .get(&url)
180            .bearer_auth(&self.api_key)
181            .headers(self.headers.clone())
182            .send()
183            .await?;
184
185        let status = response.status();
186
187        if !status.is_success() {
188            // Log the error text but don't store it in an unused variable
189            let _error_text = response.text().await.unwrap_or_default();
190            tracing::error!("Error response: {}", _error_text);
191            return Err(MiraError::ApiError(status.as_u16()));
192        }
193
194        let response_text = response.text().await?;
195
196        let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
197            tracing::error!("Failed to parse response: {}", e);
198            MiraError::JsonError(e)
199        })?;
200
201        Ok(models.data.into_iter().map(|model| model.id).collect())
202    }
203}
204
205impl ProviderClient for Client {
206    /// Create a new Mira client from the `MIRA_API_KEY` environment variable.
207    /// Panics if the environment variable is not set.
208    fn from_env() -> Self {
209        let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
210        Self::new(&api_key).expect("Could not create Mira Client")
211    }
212}
213
214impl CompletionClient for Client {
215    type CompletionModel = CompletionModel;
216    /// Create a completion model with the given name.
217    fn completion_model(&self, model: &str) -> CompletionModel {
218        CompletionModel::new(self.to_owned(), model)
219    }
220}
221
// `CompletionClient` is the only capability trait this module implements for `Client`.
// The macro below is invoked for the embedding, transcription, image-generation and
// audio-generation conversion traits. NOTE(review): presumably it generates
// "capability not supported" implementations; confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
228
229#[derive(Clone)]
230pub struct CompletionModel {
231    client: Client,
232    /// Name of the model
233    pub model: String,
234}
235
236impl CompletionModel {
237    pub fn new(client: Client, model: &str) -> Self {
238        Self {
239            client,
240            model: model.to_string(),
241        }
242    }
243
244    fn create_completion_request(
245        &self,
246        completion_request: CompletionRequest,
247    ) -> Result<Value, CompletionError> {
248        let mut messages = Vec::new();
249
250        // Add preamble as user message if available
251        if let Some(preamble) = &completion_request.preamble {
252            messages.push(serde_json::json!({
253                "role": "user",
254                "content": preamble.to_string()
255            }));
256        }
257
258        // Add docs
259        if let Some(Message::User { content }) = completion_request.normalized_documents() {
260            let text = content
261                .into_iter()
262                .filter_map(|doc| match doc {
263                    UserContent::Document(doc) => Some(doc.data),
264                    UserContent::Text(text) => Some(text.text),
265
266                    // This should always be `Document`
267                    _ => None,
268                })
269                .collect::<Vec<_>>()
270                .join("\n");
271
272            messages.push(serde_json::json!({
273                "role": "user",
274                "content": text
275            }));
276        }
277
278        // Add chat history
279        for msg in completion_request.chat_history {
280            let (role, content) = match msg {
281                Message::User { content } => {
282                    let text = content
283                        .iter()
284                        .map(|c| match c {
285                            UserContent::Text(text) => &text.text,
286                            _ => "",
287                        })
288                        .collect::<Vec<_>>()
289                        .join("\n");
290                    ("user", text)
291                }
292                Message::Assistant { content, .. } => {
293                    let text = content
294                        .iter()
295                        .map(|c| match c {
296                            AssistantContent::Text(text) => &text.text,
297                            _ => "",
298                        })
299                        .collect::<Vec<_>>()
300                        .join("\n");
301                    ("assistant", text)
302                }
303            };
304            messages.push(serde_json::json!({
305                "role": role,
306                "content": content
307            }));
308        }
309
310        let request = serde_json::json!({
311            "model": self.model,
312            "messages": messages,
313            "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
314            "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
315            "stream": false
316        });
317
318        Ok(request)
319    }
320}
321
322impl completion::CompletionModel for CompletionModel {
323    type Response = CompletionResponse;
324    type StreamingResponse = openai::StreamingCompletionResponse;
325
326    #[cfg_attr(feature = "worker", worker::send)]
327    async fn completion(
328        &self,
329        completion_request: CompletionRequest,
330    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
331        if !completion_request.tools.is_empty() {
332            tracing::warn!(target: "rig",
333                "Tool calls are not supported by the Mira provider. {} tools will be ignored.",
334                completion_request.tools.len()
335            );
336        }
337
338        let mira_request = self.create_completion_request(completion_request)?;
339
340        let response = self
341            .client
342            .http_client
343            .post(format!("{}/v1/chat/completions", self.client.base_url))
344            .bearer_auth(&self.client.api_key)
345            .headers(self.client.headers.clone())
346            .json(&mira_request)
347            .send()
348            .await
349            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
350
351        if !response.status().is_success() {
352            let status = response.status().as_u16();
353            let error_text = response.text().await.unwrap_or_default();
354            return Err(CompletionError::ProviderError(format!(
355                "API error: {status} - {error_text}"
356            )));
357        }
358
359        let response: CompletionResponse = response
360            .json()
361            .await
362            .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
363
364        response.try_into()
365    }
366
367    #[cfg_attr(feature = "worker", worker::send)]
368    async fn stream(
369        &self,
370        completion_request: CompletionRequest,
371    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
372        let mut request = self.create_completion_request(completion_request)?;
373
374        request = merge(request, json!({"stream": true}));
375
376        let builder = self
377            .client
378            .http_client
379            .post(format!("{}/v1/chat/completions", self.client.base_url))
380            .headers(self.client.headers.clone())
381            .json(&request);
382
383        send_compatible_streaming_request(builder).await
384    }
385}
386
387impl From<ApiErrorResponse> for CompletionError {
388    fn from(err: ApiErrorResponse) -> Self {
389        CompletionError::ProviderError(err.message)
390    }
391}
392
393impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
394    type Error = CompletionError;
395
396    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
397        let content = match &response {
398            CompletionResponse::Structured { choices, .. } => {
399                let choice = choices.first().ok_or_else(|| {
400                    CompletionError::ResponseError("Response contained no choices".to_owned())
401                })?;
402
403                // Convert RawMessage to message::Message
404                let message = message::Message::try_from(choice.message.clone())?;
405
406                match message {
407                    Message::Assistant { content, .. } => {
408                        if content.is_empty() {
409                            return Err(CompletionError::ResponseError(
410                                "Response contained empty content".to_owned(),
411                            ));
412                        }
413
414                        // Log warning for unsupported content types
415                        for c in content.iter() {
416                            if !matches!(c, AssistantContent::Text(_)) {
417                                tracing::warn!(target: "rig",
418                                    "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
419                                );
420                            }
421                        }
422
423                        content.iter().map(|c| {
424                            match c {
425                                AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
426                                other => Err(CompletionError::ResponseError(
427                                    format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
428                                ))
429                            }
430                        }).collect::<Result<Vec<_>, _>>()?
431                    }
432                    Message::User { .. } => {
433                        tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
434                        return Err(CompletionError::ResponseError(
435                            "Received user message in response where assistant message was expected".to_owned()
436                        ));
437                    }
438                }
439            }
440            CompletionResponse::Simple(text) => {
441                vec![completion::AssistantContent::text(text)]
442            }
443        };
444
445        let choice = OneOrMany::many(content).map_err(|_| {
446            CompletionError::ResponseError(
447                "Response contained no message or tool call (empty)".to_owned(),
448            )
449        })?;
450
451        Ok(completion::CompletionResponse {
452            choice,
453            raw_response: response,
454        })
455    }
456}
457
458#[derive(Clone, Debug, Deserialize)]
459pub struct Usage {
460    pub prompt_tokens: usize,
461    pub total_tokens: usize,
462}
463
464impl std::fmt::Display for Usage {
465    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
466        write!(
467            f,
468            "Prompt tokens: {} Total tokens: {}",
469            self.prompt_tokens, self.total_tokens
470        )
471    }
472}
473
474impl From<Message> for serde_json::Value {
475    fn from(msg: Message) -> Self {
476        match msg {
477            Message::User { content } => {
478                let text = content
479                    .iter()
480                    .map(|c| match c {
481                        UserContent::Text(text) => &text.text,
482                        _ => "",
483                    })
484                    .collect::<Vec<_>>()
485                    .join("\n");
486                serde_json::json!({
487                    "role": "user",
488                    "content": text
489                })
490            }
491            Message::Assistant { content, .. } => {
492                let text = content
493                    .iter()
494                    .map(|c| match c {
495                        AssistantContent::Text(text) => &text.text,
496                        _ => "",
497                    })
498                    .collect::<Vec<_>>()
499                    .join("\n");
500                serde_json::json!({
501                    "role": "assistant",
502                    "content": text
503                })
504            }
505        }
506    }
507}
508
509impl TryFrom<serde_json::Value> for Message {
510    type Error = CompletionError;
511
512    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
513        let role = value["role"].as_str().ok_or_else(|| {
514            CompletionError::ResponseError("Message missing role field".to_owned())
515        })?;
516
517        // Handle both string and array content formats
518        let content = match value.get("content") {
519            Some(content) => match content {
520                serde_json::Value::String(s) => s.clone(),
521                serde_json::Value::Array(arr) => arr
522                    .iter()
523                    .filter_map(|c| {
524                        c.get("text")
525                            .and_then(|t| t.as_str())
526                            .map(|text| text.to_string())
527                    })
528                    .collect::<Vec<_>>()
529                    .join("\n"),
530                _ => {
531                    return Err(CompletionError::ResponseError(
532                        "Message content must be string or array".to_owned(),
533                    ));
534                }
535            },
536            None => {
537                return Err(CompletionError::ResponseError(
538                    "Message missing content field".to_owned(),
539                ));
540            }
541        };
542
543        match role {
544            "user" => Ok(Message::User {
545                content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
546            }),
547            "assistant" => Ok(Message::Assistant {
548                id: None,
549                content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
550            }),
551            _ => Err(CompletionError::ResponseError(format!(
552                "Unsupported message role: {role}"
553            ))),
554        }
555    }
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    const GREETING: &str = "Hello there, how may I assist you today?";

    /// Assert that `msg` is an assistant message whose first part is the text `expected`.
    fn assert_assistant_text(msg: Message, expected: &str) {
        let Message::Assistant { content, .. } = msg else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            AssistantContent::Text(message::Text {
                text: expected.to_string()
            })
        );
    }

    #[test]
    fn test_deserialize_message() {
        // String content.
        let assistant =
            Message::try_from(json!({ "role": "assistant", "content": GREETING })).unwrap();
        let user =
            Message::try_from(json!({ "role": "user", "content": "What can you help me with?" }))
                .unwrap();
        // Array-of-parts content.
        let assistant_from_parts = Message::try_from(json!({
            "role": "assistant",
            "content": [{ "type": "text", "text": GREETING }]
        }))
        .unwrap();

        assert_assistant_text(assistant, GREETING);

        let Message::User { content } = user else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text(message::Text {
                text: "What can you help me with?".to_string()
            })
        );

        assert_assistant_text(assistant_from_parts, GREETING);
    }

    #[test]
    fn test_message_conversion() {
        // A plain text user message must survive Message -> Mira JSON -> Message unchanged.
        let original = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let encoded = serde_json::Value::from(original.clone());
        let decoded: Message = encoded.try_into().unwrap();

        assert_eq!(original, decoded);
    }

    #[test]
    fn test_completion_response_conversion() {
        let choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };
        let usage = Usage {
            prompt_tokens: 10,
            total_tokens: 20,
        };
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![choice],
            usage: Some(usage),
        };

        let converted: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            converted.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
}