Skip to main content

rig/providers/
mira.rs

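//! Mira API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::mira;
//!
//! // `Client::new` is fallible, so the result has to be unwrapped before use.
//! let client = mira::Client::new("YOUR_API_KEY").expect("failed to build the Mira client");
//! ```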
10use crate::client::{
11    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
12    ProviderClient,
13};
14use crate::http_client::{self, HttpClientExt};
15use crate::message::{Document, DocumentSourceKind};
16use crate::providers::openai;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20    OneOrMany,
21    completion::{self, CompletionError, CompletionRequest},
22    message::{self, AssistantContent, Message, UserContent},
23};
24use serde::{Deserialize, Serialize};
25use std::string::FromUtf8Error;
26use thiserror::Error;
27use tracing::{self, Instrument, info_span};
28
/// Zero-sized provider extension type for Mira.
///
/// The `Provider`, `Capabilities` and `DebugExt` implementations in this module are attached
/// to this type.
#[derive(Clone, Copy, Debug, Default)]
pub struct MiraExt;
/// Zero-sized builder type whose `ProviderBuilder` implementation yields a [`MiraExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct MiraBuilder;
33
34type MiraApiKey = BearerAuth;
35
36impl Provider for MiraExt {
37    type Builder = MiraBuilder;
38
39    const VERIFY_PATH: &'static str = "/user-credits";
40}
41
42impl<H> Capabilities<H> for MiraExt {
43    type Completion = Capable<CompletionModel<H>>;
44    type Embeddings = Nothing;
45    type Transcription = Nothing;
46    type ModelListing = Nothing;
47
48    #[cfg(feature = "image")]
49    type ImageGeneration = Nothing;
50
51    #[cfg(feature = "audio")]
52    type AudioGeneration = Nothing;
53}
54
55impl DebugExt for MiraExt {}
56
57impl ProviderBuilder for MiraBuilder {
58    type Extension<H>
59        = MiraExt
60    where
61        H: HttpClientExt;
62    type ApiKey = MiraApiKey;
63
64    const BASE_URL: &'static str = MIRA_API_BASE_URL;
65
66    fn build<H>(
67        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
68    ) -> http_client::Result<Self::Extension<H>>
69    where
70        H: HttpClientExt,
71    {
72        Ok(MiraExt)
73    }
74}
75
76pub type Client<H = reqwest::Client> = client::Client<MiraExt, H>;
77pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MiraBuilder, MiraApiKey, H>;
78
79#[derive(Debug, Error)]
80pub enum MiraError {
81    #[error("Invalid API key")]
82    InvalidApiKey,
83    #[error("API error: {0}")]
84    ApiError(u16),
85    #[error("Request error: {0}")]
86    RequestError(#[from] http_client::Error),
87    #[error("UTF-8 error: {0}")]
88    Utf8Error(#[from] FromUtf8Error),
89    #[error("JSON error: {0}")]
90    JsonError(#[from] serde_json::Error),
91}
92
93#[derive(Debug, Deserialize)]
94struct ApiErrorResponse {
95    message: String,
96}
97
98#[derive(Debug, Deserialize, Clone, Serialize)]
99pub struct RawMessage {
100    pub role: String,
101    pub content: String,
102}
103
/// Base URL of the Mira API; used as `ProviderBuilder::BASE_URL` for [`MiraBuilder`].
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
105
106impl TryFrom<RawMessage> for message::Message {
107    type Error = CompletionError;
108
109    fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
110        match raw.role.as_str() {
111            "user" => Ok(message::Message::User {
112                content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
113            }),
114            "assistant" => Ok(message::Message::Assistant {
115                id: None,
116                content: OneOrMany::one(AssistantContent::Text(message::Text {
117                    text: raw.content,
118                })),
119            }),
120            _ => Err(CompletionError::ResponseError(format!(
121                "Unsupported message role: {}",
122                raw.role
123            ))),
124        }
125    }
126}
127
128#[derive(Debug, Deserialize, Serialize)]
129#[serde(untagged)]
130pub enum CompletionResponse {
131    Structured {
132        id: String,
133        object: String,
134        created: u64,
135        model: String,
136        choices: Vec<ChatChoice>,
137        #[serde(skip_serializing_if = "Option::is_none")]
138        usage: Option<Usage>,
139    },
140    Simple(String),
141}
142
143#[derive(Debug, Deserialize, Serialize)]
144pub struct ChatChoice {
145    pub message: RawMessage,
146    #[serde(default)]
147    pub finish_reason: Option<String>,
148    #[serde(default)]
149    pub index: Option<usize>,
150}
151
152#[derive(Debug, Deserialize, Serialize)]
153struct ModelsResponse {
154    data: Vec<ModelInfo>,
155}
156
157#[derive(Debug, Deserialize, Serialize)]
158struct ModelInfo {
159    id: String,
160}
161
162impl<T> Client<T>
163where
164    T: HttpClientExt + 'static,
165{
166    /// List available models
167    pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
168        let req = self.get("/v1/models").and_then(|req| {
169            req.body(http_client::NoBody)
170                .map_err(http_client::Error::Protocol)
171        })?;
172
173        let response = self.send(req).await?;
174
175        let status = response.status();
176
177        if !status.is_success() {
178            // Log the error text but don't store it in an unused variable
179            let error_text = http_client::text(response).await.unwrap_or_default();
180            tracing::error!("Error response: {}", error_text);
181            return Err(MiraError::ApiError(status.as_u16()));
182        }
183
184        let response_text = http_client::text(response).await?;
185
186        let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
187            tracing::error!("Failed to parse response: {}", e);
188            MiraError::JsonError(e)
189        })?;
190
191        Ok(models.data.into_iter().map(|model| model.id).collect())
192    }
193}
194
195impl ProviderClient for Client {
196    type Input = String;
197
198    /// Create a new Mira client from the `MIRA_API_KEY` environment variable.
199    /// Panics if the environment variable is not set.
200    fn from_env() -> Self {
201        let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
202        Self::new(&api_key).unwrap()
203    }
204
205    fn from_val(input: Self::Input) -> Self {
206        Self::new(&input).unwrap()
207    }
208}
209
210#[derive(Debug, Serialize, Deserialize)]
211pub(super) struct MiraCompletionRequest {
212    model: String,
213    pub messages: Vec<RawMessage>,
214    #[serde(skip_serializing_if = "Option::is_none")]
215    temperature: Option<f64>,
216    #[serde(skip_serializing_if = "Option::is_none")]
217    max_tokens: Option<u64>,
218    pub stream: bool,
219}
220
221impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest {
222    type Error = CompletionError;
223
224    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
225        if req.output_schema.is_some() {
226            tracing::warn!("Structured outputs currently not supported for Mira");
227        }
228        let model = req.model.clone().unwrap_or_else(|| model.to_string());
229        let mut messages = Vec::new();
230
231        if let Some(content) = &req.preamble {
232            messages.push(RawMessage {
233                role: "user".to_string(),
234                content: content.to_string(),
235            });
236        }
237
238        if let Some(Message::User { content }) = req.normalized_documents() {
239            let text = content
240                .into_iter()
241                .filter_map(|doc| match doc {
242                    UserContent::Document(Document {
243                        data: DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data),
244                        ..
245                    }) => Some(data),
246                    UserContent::Text(text) => Some(text.text),
247
248                    // This should always be `Document`
249                    _ => None,
250                })
251                .collect::<Vec<_>>()
252                .join("\n");
253
254            messages.push(RawMessage {
255                role: "user".to_string(),
256                content: text,
257            });
258        }
259
260        for msg in req.chat_history {
261            let (role, content) = match msg {
262                Message::User { content } => {
263                    let text = content
264                        .iter()
265                        .map(|c| match c {
266                            UserContent::Text(text) => &text.text,
267                            _ => "",
268                        })
269                        .collect::<Vec<_>>()
270                        .join("\n");
271                    ("user", text)
272                }
273                Message::Assistant { content, .. } => {
274                    let text = content
275                        .iter()
276                        .map(|c| match c {
277                            AssistantContent::Text(text) => &text.text,
278                            _ => "",
279                        })
280                        .collect::<Vec<_>>()
281                        .join("\n");
282                    ("assistant", text)
283                }
284            };
285            messages.push(RawMessage {
286                role: role.to_string(),
287                content,
288            });
289        }
290
291        Ok(Self {
292            model: model.to_string(),
293            messages,
294            temperature: req.temperature,
295            max_tokens: req.max_tokens,
296            stream: false,
297        })
298    }
299}
300
301#[derive(Clone)]
302pub struct CompletionModel<T = reqwest::Client> {
303    client: Client<T>,
304    /// Name of the model
305    pub model: String,
306}
307
308impl<T> CompletionModel<T> {
309    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
310        Self {
311            client,
312            model: model.into(),
313        }
314    }
315}
316
317impl<T> completion::CompletionModel for CompletionModel<T>
318where
319    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
320{
321    type Response = CompletionResponse;
322    type StreamingResponse = openai::StreamingCompletionResponse;
323
324    type Client = Client<T>;
325
326    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
327        Self::new(client.clone(), model)
328    }
329
330    async fn completion(
331        &self,
332        completion_request: CompletionRequest,
333    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
334        let span = if tracing::Span::current().is_disabled() {
335            info_span!(
336                target: "rig::completions",
337                "chat",
338                gen_ai.operation.name = "chat",
339                gen_ai.provider.name = "mira",
340                gen_ai.request.model = self.model,
341                gen_ai.system_instructions = tracing::field::Empty,
342                gen_ai.response.id = tracing::field::Empty,
343                gen_ai.response.model = tracing::field::Empty,
344                gen_ai.usage.output_tokens = tracing::field::Empty,
345                gen_ai.usage.input_tokens = tracing::field::Empty,
346            )
347        } else {
348            tracing::Span::current()
349        };
350
351        span.record("gen_ai.system_instructions", &completion_request.preamble);
352
353        if !completion_request.tools.is_empty() {
354            tracing::warn!(target: "rig::completions",
355                "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
356                len = completion_request.tools.len()
357            );
358        }
359
360        if completion_request.tool_choice.is_some() {
361            tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
362        }
363
364        if completion_request.additional_params.is_some() {
365            tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
366        }
367
368        let request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
369
370        if tracing::enabled!(tracing::Level::TRACE) {
371            tracing::trace!(target: "rig::completions",
372                "Mira completion request: {}",
373                serde_json::to_string_pretty(&request)?
374            );
375        }
376
377        let body = serde_json::to_vec(&request)?;
378
379        let req = self
380            .client
381            .post("/v1/chat/completions")?
382            .body(body)
383            .map_err(http_client::Error::from)?;
384
385        let async_block = async move {
386            let response = self
387                .client
388                .send::<_, bytes::Bytes>(req)
389                .await
390                .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
391
392            let status = response.status();
393            let response_body = response.into_body().into_future().await?.to_vec();
394
395            if !status.is_success() {
396                let status = status.as_u16();
397                let error_text = String::from_utf8_lossy(&response_body).to_string();
398                return Err(CompletionError::ProviderError(format!(
399                    "API error: {status} - {error_text}"
400                )));
401            }
402
403            let response: CompletionResponse = serde_json::from_slice(&response_body)?;
404
405            if tracing::enabled!(tracing::Level::TRACE) {
406                tracing::trace!(target: "rig::completions",
407                    "Mira completion response: {}",
408                    serde_json::to_string_pretty(&response)?
409                );
410            }
411
412            if let CompletionResponse::Structured {
413                id, model, usage, ..
414            } = &response
415            {
416                let span = tracing::Span::current();
417                span.record("gen_ai.response.model_name", model);
418                span.record("gen_ai.response.id", id);
419                if let Some(usage) = usage {
420                    span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
421                    span.record(
422                        "gen_ai.usage.output_tokens",
423                        usage.total_tokens - usage.prompt_tokens,
424                    );
425                }
426            }
427
428            response.try_into()
429        };
430
431        async_block.instrument(span).await
432    }
433
434    async fn stream(
435        &self,
436        completion_request: CompletionRequest,
437    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
438        let span = if tracing::Span::current().is_disabled() {
439            info_span!(
440                target: "rig::completions",
441                "chat_streaming",
442                gen_ai.operation.name = "chat_streaming",
443                gen_ai.provider.name = "mira",
444                gen_ai.request.model = self.model,
445                gen_ai.system_instructions = tracing::field::Empty,
446                gen_ai.response.id = tracing::field::Empty,
447                gen_ai.response.model = tracing::field::Empty,
448                gen_ai.usage.output_tokens = tracing::field::Empty,
449                gen_ai.usage.input_tokens = tracing::field::Empty,
450            )
451        } else {
452            tracing::Span::current()
453        };
454
455        span.record("gen_ai.system_instructions", &completion_request.preamble);
456
457        if !completion_request.tools.is_empty() {
458            tracing::warn!(target: "rig::completions",
459                "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
460                len = completion_request.tools.len()
461            );
462        }
463
464        if completion_request.tool_choice.is_some() {
465            tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
466        }
467
468        if completion_request.additional_params.is_some() {
469            tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
470        }
471        let mut request =
472            MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
473        request.stream = true;
474
475        if tracing::enabled!(tracing::Level::TRACE) {
476            tracing::trace!(target: "rig::completions",
477                "Mira completion request: {}",
478                serde_json::to_string_pretty(&request)?
479            );
480        }
481
482        let body = serde_json::to_vec(&request)?;
483
484        let req = self
485            .client
486            .post("/v1/chat/completions")?
487            .body(body)
488            .map_err(http_client::Error::from)?;
489
490        send_compatible_streaming_request(self.client.clone(), req)
491            .instrument(span)
492            .await
493    }
494}
495
496impl From<ApiErrorResponse> for CompletionError {
497    fn from(err: ApiErrorResponse) -> Self {
498        CompletionError::ProviderError(err.message)
499    }
500}
501
502impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
503    type Error = CompletionError;
504
505    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
506        let (content, usage) = match &response {
507            CompletionResponse::Structured { choices, usage, .. } => {
508                let choice = choices.first().ok_or_else(|| {
509                    CompletionError::ResponseError("Response contained no choices".to_owned())
510                })?;
511
512                let usage = usage
513                    .as_ref()
514                    .map(|usage| completion::Usage {
515                        input_tokens: usage.prompt_tokens as u64,
516                        output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
517                        total_tokens: usage.total_tokens as u64,
518                        cached_input_tokens: 0,
519                    })
520                    .unwrap_or_default();
521
522                // Convert RawMessage to message::Message
523                let message = message::Message::try_from(choice.message.clone())?;
524
525                let content = match message {
526                    Message::Assistant { content, .. } => {
527                        if content.is_empty() {
528                            return Err(CompletionError::ResponseError(
529                                "Response contained empty content".to_owned(),
530                            ));
531                        }
532
533                        // Log warning for unsupported content types
534                        for c in content.iter() {
535                            if !matches!(c, AssistantContent::Text(_)) {
536                                tracing::warn!(target: "rig",
537                                    "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
538                                );
539                            }
540                        }
541
542                        content.iter().map(|c| {
543                            match c {
544                                AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
545                                other => Err(CompletionError::ResponseError(
546                                    format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
547                                ))
548                            }
549                        }).collect::<Result<Vec<_>, _>>()?
550                    }
551                    Message::User { .. } => {
552                        tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
553                        return Err(CompletionError::ResponseError(
554                            "Received user message in response where assistant message was expected".to_owned()
555                        ));
556                    }
557                };
558
559                (content, usage)
560            }
561            CompletionResponse::Simple(text) => (
562                vec![completion::AssistantContent::text(text)],
563                completion::Usage::new(),
564            ),
565        };
566
567        let choice = OneOrMany::many(content).map_err(|_| {
568            CompletionError::ResponseError(
569                "Response contained no message or tool call (empty)".to_owned(),
570            )
571        })?;
572
573        Ok(completion::CompletionResponse {
574            choice,
575            usage,
576            raw_response: response,
577            message_id: None,
578        })
579    }
580}
581
582#[derive(Clone, Debug, Deserialize, Serialize)]
583pub struct Usage {
584    pub prompt_tokens: usize,
585    pub total_tokens: usize,
586}
587
588impl std::fmt::Display for Usage {
589    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
590        write!(
591            f,
592            "Prompt tokens: {} Total tokens: {}",
593            self.prompt_tokens, self.total_tokens
594        )
595    }
596}
597
598impl From<Message> for serde_json::Value {
599    fn from(msg: Message) -> Self {
600        match msg {
601            Message::User { content } => {
602                let text = content
603                    .iter()
604                    .map(|c| match c {
605                        UserContent::Text(text) => &text.text,
606                        _ => "",
607                    })
608                    .collect::<Vec<_>>()
609                    .join("\n");
610                serde_json::json!({
611                    "role": "user",
612                    "content": text
613                })
614            }
615            Message::Assistant { content, .. } => {
616                let text = content
617                    .iter()
618                    .map(|c| match c {
619                        AssistantContent::Text(text) => &text.text,
620                        _ => "",
621                    })
622                    .collect::<Vec<_>>()
623                    .join("\n");
624                serde_json::json!({
625                    "role": "assistant",
626                    "content": text
627                })
628            }
629        }
630    }
631}
632
633impl TryFrom<serde_json::Value> for Message {
634    type Error = CompletionError;
635
636    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
637        let role = value["role"].as_str().ok_or_else(|| {
638            CompletionError::ResponseError("Message missing role field".to_owned())
639        })?;
640
641        // Handle both string and array content formats
642        let content = match value.get("content") {
643            Some(content) => match content {
644                serde_json::Value::String(s) => s.clone(),
645                serde_json::Value::Array(arr) => arr
646                    .iter()
647                    .filter_map(|c| {
648                        c.get("text")
649                            .and_then(|t| t.as_str())
650                            .map(|text| text.to_string())
651                    })
652                    .collect::<Vec<_>>()
653                    .join("\n"),
654                _ => {
655                    return Err(CompletionError::ResponseError(
656                        "Message content must be string or array".to_owned(),
657                    ));
658                }
659            },
660            None => {
661                return Err(CompletionError::ResponseError(
662                    "Message missing content field".to_owned(),
663                ));
664            }
665        };
666
667        match role {
668            "user" => Ok(Message::User {
669                content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
670            }),
671            "assistant" => Ok(Message::Assistant {
672                id: None,
673                content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
674            }),
675            _ => Err(CompletionError::ResponseError(format!(
676                "Unsupported message role: {role}"
677            ))),
678        }
679    }
680}
681
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// JSON messages deserialize into `Message` for both string and array `content`.
    #[test]
    fn test_deserialize_message() {
        const GREETING: &str = "Hello there, how may I assist you today?";

        // String content format.
        let assistant_message = Message::try_from(json!({
            "role": "assistant",
            "content": "Hello there, how may I assist you today?"
        }))
        .unwrap();

        let user_message = Message::try_from(json!({
            "role": "user",
            "content": "What can you help me with?"
        }))
        .unwrap();

        // Array content format.
        let assistant_message_array = Message::try_from(json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Hello there, how may I assist you today?"
            }]
        }))
        .unwrap();

        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            AssistantContent::Text(message::Text {
                text: GREETING.to_string()
            })
        );

        let Message::User { content } = user_message else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text(message::Text {
                text: "What can you help me with?".to_string()
            })
        );

        let Message::Assistant { content, .. } = assistant_message_array else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            AssistantContent::Text(message::Text {
                text: GREETING.to_string()
            })
        );
    }

    /// A user message survives a round trip through the JSON representation.
    #[test]
    fn test_message_conversion() {
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        // To the JSON representation ...
        let as_json = serde_json::Value::from(original_message.clone());

        // ... and back again.
        let round_tripped = Message::try_from(as_json).unwrap();

        assert_eq!(original_message, round_tripped);
    }

    /// A structured response converts into a generic response holding the choice's text.
    #[test]
    fn test_completion_response_conversion() {
        let only_choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };

        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![only_choice],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let converted =
            completion::CompletionResponse::<CompletionResponse>::try_from(mira_response)
                .unwrap();

        assert_eq!(
            converted.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }

    /// Both construction paths (`new` and the builder) succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::mira::Client::new("dummy-key").expect("Client::new() failed");

        let _client_from_builder = crate::providers::mira::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}