1use crate::client::{
11 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
12 ProviderClient,
13};
14use crate::http_client::{self, HttpClientExt};
15use crate::message::{Document, DocumentSourceKind};
16use crate::providers::openai;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 message::{self, AssistantContent, Message, UserContent},
23};
24use serde::{Deserialize, Serialize};
25use std::string::FromUtf8Error;
26use thiserror::Error;
27use tracing::{self, Instrument, info_span};
28
29#[derive(Debug, Default, Clone, Copy)]
30pub struct MiraExt;
31#[derive(Debug, Default, Clone, Copy)]
32pub struct MiraBuilder;
33
34type MiraApiKey = BearerAuth;
35
36impl Provider for MiraExt {
37 type Builder = MiraBuilder;
38
39 const VERIFY_PATH: &'static str = "/user-credits";
40}
41
42impl<H> Capabilities<H> for MiraExt {
43 type Completion = Capable<CompletionModel<H>>;
44 type Embeddings = Nothing;
45 type Transcription = Nothing;
46 type ModelListing = Nothing;
47
48 #[cfg(feature = "image")]
49 type ImageGeneration = Nothing;
50
51 #[cfg(feature = "audio")]
52 type AudioGeneration = Nothing;
53}
54
55impl DebugExt for MiraExt {}
56
57impl ProviderBuilder for MiraBuilder {
58 type Extension<H>
59 = MiraExt
60 where
61 H: HttpClientExt;
62 type ApiKey = MiraApiKey;
63
64 const BASE_URL: &'static str = MIRA_API_BASE_URL;
65
66 fn build<H>(
67 _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
68 ) -> http_client::Result<Self::Extension<H>>
69 where
70 H: HttpClientExt,
71 {
72 Ok(MiraExt)
73 }
74}
75
76pub type Client<H = reqwest::Client> = client::Client<MiraExt, H>;
77pub type ClientBuilder<H = crate::markers::Missing> =
78 client::ClientBuilder<MiraBuilder, MiraApiKey, H>;
79
80#[derive(Debug, Error)]
81pub enum MiraError {
82 #[error("Invalid API key")]
83 InvalidApiKey,
84 #[error("API error: {0}")]
85 ApiError(u16),
86 #[error("Request error: {0}")]
87 RequestError(#[from] http_client::Error),
88 #[error("UTF-8 error: {0}")]
89 Utf8Error(#[from] FromUtf8Error),
90 #[error("JSON error: {0}")]
91 JsonError(#[from] serde_json::Error),
92}
93
94#[derive(Debug, Deserialize)]
95struct ApiErrorResponse {
96 message: String,
97}
98
99#[derive(Debug, Deserialize, Clone, Serialize)]
100pub struct RawMessage {
101 pub role: String,
102 pub content: String,
103}
104
/// Default base URL for the Mira API; endpoint paths such as `/v1/chat/completions`
/// are appended to it.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
106
107impl TryFrom<RawMessage> for message::Message {
108 type Error = CompletionError;
109
110 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
111 match raw.role.as_str() {
112 "system" => Ok(message::Message::System {
113 content: raw.content,
114 }),
115 "user" => Ok(message::Message::User {
116 content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
117 }),
118 "assistant" => Ok(message::Message::Assistant {
119 id: None,
120 content: OneOrMany::one(AssistantContent::Text(message::Text {
121 text: raw.content,
122 })),
123 }),
124 _ => Err(CompletionError::ResponseError(format!(
125 "Unsupported message role: {}",
126 raw.role
127 ))),
128 }
129 }
130}
131
132#[derive(Debug, Deserialize, Serialize)]
133#[serde(untagged)]
134pub enum CompletionResponse {
135 Structured {
136 id: String,
137 object: String,
138 created: u64,
139 model: String,
140 choices: Vec<ChatChoice>,
141 #[serde(skip_serializing_if = "Option::is_none")]
142 usage: Option<Usage>,
143 },
144 Simple(String),
145}
146
147#[derive(Debug, Deserialize, Serialize)]
148pub struct ChatChoice {
149 pub message: RawMessage,
150 #[serde(default)]
151 pub finish_reason: Option<String>,
152 #[serde(default)]
153 pub index: Option<usize>,
154}
155
156#[derive(Debug, Deserialize, Serialize)]
157struct ModelsResponse {
158 data: Vec<ModelInfo>,
159}
160
161#[derive(Debug, Deserialize, Serialize)]
162struct ModelInfo {
163 id: String,
164}
165
166impl<T> Client<T>
167where
168 T: HttpClientExt + 'static,
169{
170 pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
172 let req = self.get("/v1/models").and_then(|req| {
173 req.body(http_client::NoBody)
174 .map_err(http_client::Error::Protocol)
175 })?;
176
177 let response = self.send(req).await?;
178
179 let status = response.status();
180
181 if !status.is_success() {
182 let error_text = http_client::text(response).await.unwrap_or_default();
184 tracing::error!("Error response: {}", error_text);
185 return Err(MiraError::ApiError(status.as_u16()));
186 }
187
188 let response_text = http_client::text(response).await?;
189
190 let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
191 tracing::error!("Failed to parse response: {}", e);
192 MiraError::JsonError(e)
193 })?;
194
195 Ok(models.data.into_iter().map(|model| model.id).collect())
196 }
197}
198
199impl ProviderClient for Client {
200 type Input = String;
201 type Error = crate::client::ProviderClientError;
202
203 fn from_env() -> Result<Self, Self::Error> {
205 let api_key = crate::client::required_env_var("MIRA_API_KEY")?;
206 Self::new(&api_key).map_err(Into::into)
207 }
208
209 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
210 Self::new(&input).map_err(Into::into)
211 }
212}
213
214#[derive(Debug, Serialize, Deserialize)]
215pub(super) struct MiraCompletionRequest {
216 model: String,
217 pub messages: Vec<RawMessage>,
218 #[serde(skip_serializing_if = "Option::is_none")]
219 temperature: Option<f64>,
220 #[serde(skip_serializing_if = "Option::is_none")]
221 max_tokens: Option<u64>,
222 pub stream: bool,
223}
224
225impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest {
226 type Error = CompletionError;
227
228 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
229 if req.output_schema.is_some() {
230 tracing::warn!("Structured outputs currently not supported for Mira");
231 }
232 let model = req.model.clone().unwrap_or_else(|| model.to_string());
233 let mut messages = Vec::new();
234
235 if let Some(content) = &req.preamble {
236 messages.push(RawMessage {
237 role: "user".to_string(),
238 content: content.to_string(),
239 });
240 }
241
242 if let Some(Message::User { content }) = req.normalized_documents() {
243 let text = content
244 .into_iter()
245 .filter_map(|doc| match doc {
246 UserContent::Document(Document {
247 data: DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data),
248 ..
249 }) => Some(data),
250 UserContent::Text(text) => Some(text.text),
251
252 _ => None,
254 })
255 .collect::<Vec<_>>()
256 .join("\n");
257
258 messages.push(RawMessage {
259 role: "user".to_string(),
260 content: text,
261 });
262 }
263
264 for msg in req.chat_history {
265 let (role, content) = match msg {
266 Message::System { content } => ("system", content),
267 Message::User { content } => {
268 let text = content
269 .iter()
270 .map(|c| match c {
271 UserContent::Text(text) => &text.text,
272 _ => "",
273 })
274 .collect::<Vec<_>>()
275 .join("\n");
276 ("user", text)
277 }
278 Message::Assistant { content, .. } => {
279 let text = content
280 .iter()
281 .map(|c| match c {
282 AssistantContent::Text(text) => &text.text,
283 _ => "",
284 })
285 .collect::<Vec<_>>()
286 .join("\n");
287 ("assistant", text)
288 }
289 };
290 messages.push(RawMessage {
291 role: role.to_string(),
292 content,
293 });
294 }
295
296 Ok(Self {
297 model: model.to_string(),
298 messages,
299 temperature: req.temperature,
300 max_tokens: req.max_tokens,
301 stream: false,
302 })
303 }
304}
305
306#[derive(Clone)]
307pub struct CompletionModel<T = reqwest::Client> {
308 client: Client<T>,
309 pub model: String,
311}
312
313impl<T> CompletionModel<T> {
314 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
315 Self {
316 client,
317 model: model.into(),
318 }
319 }
320}
321
322impl<T> completion::CompletionModel for CompletionModel<T>
323where
324 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
325{
326 type Response = CompletionResponse;
327 type StreamingResponse = openai::StreamingCompletionResponse;
328
329 type Client = Client<T>;
330
331 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
332 Self::new(client.clone(), model)
333 }
334
335 async fn completion(
336 &self,
337 completion_request: CompletionRequest,
338 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
339 let span = if tracing::Span::current().is_disabled() {
340 info_span!(
341 target: "rig::completions",
342 "chat",
343 gen_ai.operation.name = "chat",
344 gen_ai.provider.name = "mira",
345 gen_ai.request.model = self.model,
346 gen_ai.system_instructions = tracing::field::Empty,
347 gen_ai.response.id = tracing::field::Empty,
348 gen_ai.response.model = tracing::field::Empty,
349 gen_ai.usage.output_tokens = tracing::field::Empty,
350 gen_ai.usage.input_tokens = tracing::field::Empty,
351 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
352 )
353 } else {
354 tracing::Span::current()
355 };
356
357 span.record("gen_ai.system_instructions", &completion_request.preamble);
358
359 if !completion_request.tools.is_empty() {
360 tracing::warn!(target: "rig::completions",
361 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
362 len = completion_request.tools.len()
363 );
364 }
365
366 if completion_request.tool_choice.is_some() {
367 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
368 }
369
370 if completion_request.additional_params.is_some() {
371 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
372 }
373
374 let request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
375
376 if tracing::enabled!(tracing::Level::TRACE) {
377 tracing::trace!(target: "rig::completions",
378 "Mira completion request: {}",
379 serde_json::to_string_pretty(&request)?
380 );
381 }
382
383 let body = serde_json::to_vec(&request)?;
384
385 let req = self
386 .client
387 .post("/v1/chat/completions")?
388 .body(body)
389 .map_err(http_client::Error::from)?;
390
391 let async_block = async move {
392 let response = self
393 .client
394 .send::<_, bytes::Bytes>(req)
395 .await
396 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
397
398 let status = response.status();
399 let response_body = response.into_body().into_future().await?.to_vec();
400
401 if !status.is_success() {
402 let status = status.as_u16();
403 let error_text = String::from_utf8_lossy(&response_body).to_string();
404 return Err(CompletionError::ProviderError(format!(
405 "API error: {status} - {error_text}"
406 )));
407 }
408
409 let response: CompletionResponse = serde_json::from_slice(&response_body)?;
410
411 if tracing::enabled!(tracing::Level::TRACE) {
412 tracing::trace!(target: "rig::completions",
413 "Mira completion response: {}",
414 serde_json::to_string_pretty(&response)?
415 );
416 }
417
418 if let CompletionResponse::Structured {
419 id, model, usage, ..
420 } = &response
421 {
422 let span = tracing::Span::current();
423 span.record("gen_ai.response.model", model);
424 span.record("gen_ai.response.id", id);
425 if let Some(usage) = usage {
426 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
427 span.record(
428 "gen_ai.usage.output_tokens",
429 usage.total_tokens - usage.prompt_tokens,
430 );
431 }
432 }
433
434 response.try_into()
435 };
436
437 async_block.instrument(span).await
438 }
439
440 async fn stream(
441 &self,
442 completion_request: CompletionRequest,
443 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
444 let span = if tracing::Span::current().is_disabled() {
445 info_span!(
446 target: "rig::completions",
447 "chat_streaming",
448 gen_ai.operation.name = "chat_streaming",
449 gen_ai.provider.name = "mira",
450 gen_ai.request.model = self.model,
451 gen_ai.system_instructions = tracing::field::Empty,
452 gen_ai.response.id = tracing::field::Empty,
453 gen_ai.response.model = tracing::field::Empty,
454 gen_ai.usage.output_tokens = tracing::field::Empty,
455 gen_ai.usage.input_tokens = tracing::field::Empty,
456 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
457 )
458 } else {
459 tracing::Span::current()
460 };
461
462 span.record("gen_ai.system_instructions", &completion_request.preamble);
463
464 if !completion_request.tools.is_empty() {
465 tracing::warn!(target: "rig::completions",
466 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
467 len = completion_request.tools.len()
468 );
469 }
470
471 if completion_request.tool_choice.is_some() {
472 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
473 }
474
475 if completion_request.additional_params.is_some() {
476 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
477 }
478 let mut request =
479 MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
480 request.stream = true;
481
482 if tracing::enabled!(tracing::Level::TRACE) {
483 tracing::trace!(target: "rig::completions",
484 "Mira completion request: {}",
485 serde_json::to_string_pretty(&request)?
486 );
487 }
488
489 let body = serde_json::to_vec(&request)?;
490
491 let req = self
492 .client
493 .post("/v1/chat/completions")?
494 .body(body)
495 .map_err(http_client::Error::from)?;
496
497 send_compatible_streaming_request(self.client.clone(), req)
498 .instrument(span)
499 .await
500 }
501}
502
503impl From<ApiErrorResponse> for CompletionError {
504 fn from(err: ApiErrorResponse) -> Self {
505 CompletionError::ProviderError(err.message)
506 }
507}
508
509impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
510 type Error = CompletionError;
511
512 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
513 let (content, usage) = match &response {
514 CompletionResponse::Structured { choices, usage, .. } => {
515 let choice = choices.first().ok_or_else(|| {
516 CompletionError::ResponseError("Response contained no choices".to_owned())
517 })?;
518
519 let usage = usage
520 .as_ref()
521 .map(|usage| completion::Usage {
522 input_tokens: usage.prompt_tokens as u64,
523 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
524 total_tokens: usage.total_tokens as u64,
525 cached_input_tokens: 0,
526 cache_creation_input_tokens: 0,
527 reasoning_tokens: 0,
528 })
529 .unwrap_or_default();
530
531 let message = message::Message::try_from(choice.message.clone())?;
533
534 let content = match message {
535 Message::Assistant { content, .. } => {
536 if content.is_empty() {
537 return Err(CompletionError::ResponseError(
538 "Response contained empty content".to_owned(),
539 ));
540 }
541
542 for c in content.iter() {
544 if !matches!(c, AssistantContent::Text(_)) {
545 tracing::warn!(target: "rig",
546 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
547 );
548 }
549 }
550
551 content.iter().map(|c| {
552 match c {
553 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
554 other => Err(CompletionError::ResponseError(
555 format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
556 ))
557 }
558 }).collect::<Result<Vec<_>, _>>()?
559 }
560 Message::User { .. } => {
561 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
562 return Err(CompletionError::ResponseError(
563 "Received user message in response where assistant message was expected".to_owned()
564 ));
565 }
566 Message::System { .. } => {
567 tracing::warn!(target: "rig", "Received system message in response where assistant message was expected");
568 return Err(CompletionError::ResponseError(
569 "Received system message in response where assistant message was expected".to_owned(),
570 ));
571 }
572 };
573
574 (content, usage)
575 }
576 CompletionResponse::Simple(text) => (
577 vec![completion::AssistantContent::text(text)],
578 completion::Usage::new(),
579 ),
580 };
581
582 let choice = OneOrMany::many(content).map_err(|_| {
583 CompletionError::ResponseError(
584 "Response contained no message or tool call (empty)".to_owned(),
585 )
586 })?;
587
588 Ok(completion::CompletionResponse {
589 choice,
590 usage,
591 raw_response: response,
592 message_id: None,
593 })
594 }
595}
596
597#[derive(Clone, Debug, Deserialize, Serialize)]
598pub struct Usage {
599 pub prompt_tokens: usize,
600 pub total_tokens: usize,
601}
602
603impl std::fmt::Display for Usage {
604 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
605 write!(
606 f,
607 "Prompt tokens: {} Total tokens: {}",
608 self.prompt_tokens, self.total_tokens
609 )
610 }
611}
612
613impl From<Message> for serde_json::Value {
614 fn from(msg: Message) -> Self {
615 match msg {
616 Message::System { content } => serde_json::json!({
617 "role": "system",
618 "content": content
619 }),
620 Message::User { content } => {
621 let text = content
622 .iter()
623 .map(|c| match c {
624 UserContent::Text(text) => &text.text,
625 _ => "",
626 })
627 .collect::<Vec<_>>()
628 .join("\n");
629 serde_json::json!({
630 "role": "user",
631 "content": text
632 })
633 }
634 Message::Assistant { content, .. } => {
635 let text = content
636 .iter()
637 .map(|c| match c {
638 AssistantContent::Text(text) => &text.text,
639 _ => "",
640 })
641 .collect::<Vec<_>>()
642 .join("\n");
643 serde_json::json!({
644 "role": "assistant",
645 "content": text
646 })
647 }
648 }
649 }
650}
651
652impl TryFrom<serde_json::Value> for Message {
653 type Error = CompletionError;
654
655 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
656 let role = value
657 .get("role")
658 .and_then(serde_json::Value::as_str)
659 .ok_or_else(|| {
660 CompletionError::ResponseError("Message missing role field".to_owned())
661 })?;
662
663 let content = match value.get("content") {
665 Some(content) => match content {
666 serde_json::Value::String(s) => s.clone(),
667 serde_json::Value::Array(arr) => arr
668 .iter()
669 .filter_map(|c| {
670 c.get("text")
671 .and_then(|t| t.as_str())
672 .map(|text| text.to_string())
673 })
674 .collect::<Vec<_>>()
675 .join("\n"),
676 _ => {
677 return Err(CompletionError::ResponseError(
678 "Message content must be string or array".to_owned(),
679 ));
680 }
681 },
682 None => {
683 return Err(CompletionError::ResponseError(
684 "Message missing content field".to_owned(),
685 ));
686 }
687 };
688
689 match role {
690 "system" => Ok(Message::System { content }),
691 "user" => Ok(Message::User {
692 content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
693 }),
694 "assistant" => Ok(Message::Assistant {
695 id: None,
696 content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
697 }),
698 _ => Err(CompletionError::ResponseError(format!(
699 "Unsupported message role: {role}"
700 ))),
701 }
702 }
703}
704
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// Assistant greeting shared by the deserialization fixtures.
    const GREETING: &str = "Hello there, how may I assist you today?";

    /// String content and array-of-parts content both become single-text messages.
    #[test]
    fn test_deserialize_message() {
        let assistant_json = json!({ "role": "assistant", "content": GREETING });
        let user_json = json!({ "role": "user", "content": "What can you help me with?" });
        let assistant_parts_json = json!({
            "role": "assistant",
            "content": [{ "type": "text", "text": GREETING }]
        });

        let assistant = Message::try_from(assistant_json).unwrap();
        let user = Message::try_from(user_json).unwrap();
        let assistant_parts = Message::try_from(assistant_parts_json).unwrap();

        let expected_greeting = AssistantContent::Text(message::Text {
            text: GREETING.to_string(),
        });

        let Message::Assistant { content, .. } = assistant else {
            panic!("Expected assistant message");
        };
        assert_eq!(content.first(), expected_greeting);

        let Message::User { content } = user else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text(message::Text {
                text: "What can you help me with?".to_string()
            })
        );

        let Message::Assistant { content, .. } = assistant_parts else {
            panic!("Expected assistant message");
        };
        assert_eq!(content.first(), expected_greeting);
    }

    /// A user text message survives a JSON round trip unchanged.
    #[test]
    fn test_message_conversion() {
        let original = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let as_json = serde_json::Value::from(original.clone());
        let round_tripped = Message::try_from(as_json).unwrap();

        assert_eq!(original, round_tripped);
    }

    /// A structured response's first choice becomes the generic response's text content.
    #[test]
    fn test_completion_response_conversion() {
        let only_choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![only_choice],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let converted =
            completion::CompletionResponse::<CompletionResponse>::try_from(mira_response)
                .unwrap();

        assert_eq!(
            converted.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }

    /// Both the direct constructor and the builder produce a client.
    #[test]
    fn test_client_initialization() {
        let _direct =
            crate::providers::mira::Client::new("dummy-key").expect("Client::new() failed");
        let _built = crate::providers::mira::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}