1use crate::client::{
11 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
12 ProviderClient,
13};
14use crate::http_client::{self, HttpClientExt};
15use crate::message::{Document, DocumentSourceKind};
16use crate::providers::openai;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 message::{self, AssistantContent, Message, UserContent},
23};
24use serde::{Deserialize, Serialize};
25use std::string::FromUtf8Error;
26use thiserror::Error;
27use tracing::{self, Instrument, info_span};
28
/// Marker type implementing rig's provider traits for Mira.
#[derive(Debug, Default, Clone, Copy)]
pub struct MiraExt;
/// Builder marker used to construct a Mira [`Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct MiraBuilder;

// Mira authenticates with a standard `Authorization: Bearer <key>` header.
type MiraApiKey = BearerAuth;
35
impl Provider for MiraExt {
    type Builder = MiraBuilder;

    // Endpoint probed when verifying that an API key is valid.
    const VERIFY_PATH: &'static str = "/user-credits";
}
41
impl<H> Capabilities<H> for MiraExt {
    /// Chat completions are the only capability Mira exposes through rig;
    /// everything else is `Nothing`.
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;

    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
54
// Default debug formatting; no provider-specific debug output.
impl DebugExt for MiraExt {}
56
impl ProviderBuilder for MiraBuilder {
    // The extension is stateless and independent of the HTTP backend.
    type Extension<H>
        = MiraExt
    where
        H: HttpClientExt;
    type ApiKey = MiraApiKey;

    const BASE_URL: &'static str = MIRA_API_BASE_URL;

    /// Builds the (stateless) Mira extension; this cannot fail.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(MiraExt)
    }
}
75
/// Mira client over a pluggable HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<MiraExt, H>;
/// Builder for [`Client`].
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MiraBuilder, MiraApiKey, H>;
78
/// Errors produced by the Mira client.
#[derive(Debug, Error)]
pub enum MiraError {
    /// The API key was rejected.
    #[error("Invalid API key")]
    InvalidApiKey,
    /// Non-success HTTP status; the payload is the status code.
    #[error("API error: {0}")]
    ApiError(u16),
    #[error("Request error: {0}")]
    RequestError(#[from] http_client::Error),
    #[error("UTF-8 error: {0}")]
    Utf8Error(#[from] FromUtf8Error),
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
}
92
/// Error payload returned by the Mira API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
97
/// Wire-format chat message: a plain role/content string pair.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct RawMessage {
    // One of "system", "user" or "assistant" (see `TryFrom<RawMessage>`).
    pub role: String,
    pub content: String,
}
103
/// Default base URL for the Mira API.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
105
106impl TryFrom<RawMessage> for message::Message {
107 type Error = CompletionError;
108
109 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
110 match raw.role.as_str() {
111 "system" => Ok(message::Message::System {
112 content: raw.content,
113 }),
114 "user" => Ok(message::Message::User {
115 content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
116 }),
117 "assistant" => Ok(message::Message::Assistant {
118 id: None,
119 content: OneOrMany::one(AssistantContent::Text(message::Text {
120 text: raw.content,
121 })),
122 }),
123 _ => Err(CompletionError::ResponseError(format!(
124 "Unsupported message role: {}",
125 raw.role
126 ))),
127 }
128 }
129}
130
/// Response from Mira's chat completion endpoint.
///
/// Deserialized untagged: an OpenAI-style object (`Structured`) is tried
/// first, falling back to a bare string (`Simple`).
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum CompletionResponse {
    Structured {
        id: String,
        object: String,
        created: u64,
        model: String,
        choices: Vec<ChatChoice>,
        #[serde(skip_serializing_if = "Option::is_none")]
        usage: Option<Usage>,
    },
    Simple(String),
}
145
/// A single completion choice within a structured response.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatChoice {
    pub message: RawMessage,
    #[serde(default)]
    pub finish_reason: Option<String>,
    #[serde(default)]
    pub index: Option<usize>,
}
154
/// Payload of `GET /v1/models`.
#[derive(Debug, Deserialize, Serialize)]
struct ModelsResponse {
    data: Vec<ModelInfo>,
}

/// A single entry in the model listing.
#[derive(Debug, Deserialize, Serialize)]
struct ModelInfo {
    id: String,
}
164
impl<T> Client<T>
where
    T: HttpClientExt + 'static,
{
    /// Fetches the model ids available from Mira's `/v1/models` endpoint.
    ///
    /// # Errors
    /// Returns `MiraError::ApiError` on a non-2xx status, and
    /// `MiraError::RequestError` / `MiraError::JsonError` on transport or
    /// decode failures.
    pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
        let req = self.get("/v1/models").and_then(|req| {
            req.body(http_client::NoBody)
                .map_err(http_client::Error::Protocol)
        })?;

        let response = self.send(req).await?;

        let status = response.status();

        if !status.is_success() {
            // Best-effort read of the error body, purely for logging.
            let error_text = http_client::text(response).await.unwrap_or_default();
            tracing::error!("Error response: {}", error_text);
            return Err(MiraError::ApiError(status.as_u16()));
        }

        let response_text = http_client::text(response).await?;

        let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
            tracing::error!("Failed to parse response: {}", e);
            MiraError::JsonError(e)
        })?;

        Ok(models.data.into_iter().map(|model| model.id).collect())
    }
}
197
198impl ProviderClient for Client {
199 type Input = String;
200
201 fn from_env() -> Self {
204 let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
205 Self::new(&api_key).unwrap()
206 }
207
208 fn from_val(input: Self::Input) -> Self {
209 Self::new(&input).unwrap()
210 }
211}
212
/// Request body for Mira's chat completion endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct MiraCompletionRequest {
    model: String,
    pub messages: Vec<RawMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    // Set to `true` for SSE streaming responses.
    pub stream: bool,
}
223
224impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest {
225 type Error = CompletionError;
226
227 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
228 if req.output_schema.is_some() {
229 tracing::warn!("Structured outputs currently not supported for Mira");
230 }
231 let model = req.model.clone().unwrap_or_else(|| model.to_string());
232 let mut messages = Vec::new();
233
234 if let Some(content) = &req.preamble {
235 messages.push(RawMessage {
236 role: "user".to_string(),
237 content: content.to_string(),
238 });
239 }
240
241 if let Some(Message::User { content }) = req.normalized_documents() {
242 let text = content
243 .into_iter()
244 .filter_map(|doc| match doc {
245 UserContent::Document(Document {
246 data: DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data),
247 ..
248 }) => Some(data),
249 UserContent::Text(text) => Some(text.text),
250
251 _ => None,
253 })
254 .collect::<Vec<_>>()
255 .join("\n");
256
257 messages.push(RawMessage {
258 role: "user".to_string(),
259 content: text,
260 });
261 }
262
263 for msg in req.chat_history {
264 let (role, content) = match msg {
265 Message::System { content } => ("system", content),
266 Message::User { content } => {
267 let text = content
268 .iter()
269 .map(|c| match c {
270 UserContent::Text(text) => &text.text,
271 _ => "",
272 })
273 .collect::<Vec<_>>()
274 .join("\n");
275 ("user", text)
276 }
277 Message::Assistant { content, .. } => {
278 let text = content
279 .iter()
280 .map(|c| match c {
281 AssistantContent::Text(text) => &text.text,
282 _ => "",
283 })
284 .collect::<Vec<_>>()
285 .join("\n");
286 ("assistant", text)
287 }
288 };
289 messages.push(RawMessage {
290 role: role.to_string(),
291 content,
292 });
293 }
294
295 Ok(Self {
296 model: model.to_string(),
297 messages,
298 temperature: req.temperature,
299 max_tokens: req.max_tokens,
300 stream: false,
301 })
302 }
303}
304
/// Chat completion model bound to a Mira [`Client`] and a model name.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    // Model identifier sent in the request body, e.g. "deepseek-r1".
    pub model: String,
}
311
312impl<T> CompletionModel<T> {
313 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
314 Self {
315 client,
316 model: model.into(),
317 }
318 }
319}
320
321impl<T> completion::CompletionModel for CompletionModel<T>
322where
323 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
324{
325 type Response = CompletionResponse;
326 type StreamingResponse = openai::StreamingCompletionResponse;
327
328 type Client = Client<T>;
329
330 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
331 Self::new(client.clone(), model)
332 }
333
334 async fn completion(
335 &self,
336 completion_request: CompletionRequest,
337 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
338 let span = if tracing::Span::current().is_disabled() {
339 info_span!(
340 target: "rig::completions",
341 "chat",
342 gen_ai.operation.name = "chat",
343 gen_ai.provider.name = "mira",
344 gen_ai.request.model = self.model,
345 gen_ai.system_instructions = tracing::field::Empty,
346 gen_ai.response.id = tracing::field::Empty,
347 gen_ai.response.model = tracing::field::Empty,
348 gen_ai.usage.output_tokens = tracing::field::Empty,
349 gen_ai.usage.input_tokens = tracing::field::Empty,
350 gen_ai.usage.cached_tokens = tracing::field::Empty,
351 )
352 } else {
353 tracing::Span::current()
354 };
355
356 span.record("gen_ai.system_instructions", &completion_request.preamble);
357
358 if !completion_request.tools.is_empty() {
359 tracing::warn!(target: "rig::completions",
360 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
361 len = completion_request.tools.len()
362 );
363 }
364
365 if completion_request.tool_choice.is_some() {
366 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
367 }
368
369 if completion_request.additional_params.is_some() {
370 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
371 }
372
373 let request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
374
375 if tracing::enabled!(tracing::Level::TRACE) {
376 tracing::trace!(target: "rig::completions",
377 "Mira completion request: {}",
378 serde_json::to_string_pretty(&request)?
379 );
380 }
381
382 let body = serde_json::to_vec(&request)?;
383
384 let req = self
385 .client
386 .post("/v1/chat/completions")?
387 .body(body)
388 .map_err(http_client::Error::from)?;
389
390 let async_block = async move {
391 let response = self
392 .client
393 .send::<_, bytes::Bytes>(req)
394 .await
395 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
396
397 let status = response.status();
398 let response_body = response.into_body().into_future().await?.to_vec();
399
400 if !status.is_success() {
401 let status = status.as_u16();
402 let error_text = String::from_utf8_lossy(&response_body).to_string();
403 return Err(CompletionError::ProviderError(format!(
404 "API error: {status} - {error_text}"
405 )));
406 }
407
408 let response: CompletionResponse = serde_json::from_slice(&response_body)?;
409
410 if tracing::enabled!(tracing::Level::TRACE) {
411 tracing::trace!(target: "rig::completions",
412 "Mira completion response: {}",
413 serde_json::to_string_pretty(&response)?
414 );
415 }
416
417 if let CompletionResponse::Structured {
418 id, model, usage, ..
419 } = &response
420 {
421 let span = tracing::Span::current();
422 span.record("gen_ai.response.model_name", model);
423 span.record("gen_ai.response.id", id);
424 if let Some(usage) = usage {
425 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
426 span.record(
427 "gen_ai.usage.output_tokens",
428 usage.total_tokens - usage.prompt_tokens,
429 );
430 }
431 }
432
433 response.try_into()
434 };
435
436 async_block.instrument(span).await
437 }
438
439 async fn stream(
440 &self,
441 completion_request: CompletionRequest,
442 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
443 let span = if tracing::Span::current().is_disabled() {
444 info_span!(
445 target: "rig::completions",
446 "chat_streaming",
447 gen_ai.operation.name = "chat_streaming",
448 gen_ai.provider.name = "mira",
449 gen_ai.request.model = self.model,
450 gen_ai.system_instructions = tracing::field::Empty,
451 gen_ai.response.id = tracing::field::Empty,
452 gen_ai.response.model = tracing::field::Empty,
453 gen_ai.usage.output_tokens = tracing::field::Empty,
454 gen_ai.usage.input_tokens = tracing::field::Empty,
455 gen_ai.usage.cached_tokens = tracing::field::Empty,
456 )
457 } else {
458 tracing::Span::current()
459 };
460
461 span.record("gen_ai.system_instructions", &completion_request.preamble);
462
463 if !completion_request.tools.is_empty() {
464 tracing::warn!(target: "rig::completions",
465 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
466 len = completion_request.tools.len()
467 );
468 }
469
470 if completion_request.tool_choice.is_some() {
471 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
472 }
473
474 if completion_request.additional_params.is_some() {
475 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
476 }
477 let mut request =
478 MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
479 request.stream = true;
480
481 if tracing::enabled!(tracing::Level::TRACE) {
482 tracing::trace!(target: "rig::completions",
483 "Mira completion request: {}",
484 serde_json::to_string_pretty(&request)?
485 );
486 }
487
488 let body = serde_json::to_vec(&request)?;
489
490 let req = self
491 .client
492 .post("/v1/chat/completions")?
493 .body(body)
494 .map_err(http_client::Error::from)?;
495
496 send_compatible_streaming_request(self.client.clone(), req)
497 .instrument(span)
498 .await
499 }
500}
501
502impl From<ApiErrorResponse> for CompletionError {
503 fn from(err: ApiErrorResponse) -> Self {
504 CompletionError::ProviderError(err.message)
505 }
506}
507
508impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
509 type Error = CompletionError;
510
511 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
512 let (content, usage) = match &response {
513 CompletionResponse::Structured { choices, usage, .. } => {
514 let choice = choices.first().ok_or_else(|| {
515 CompletionError::ResponseError("Response contained no choices".to_owned())
516 })?;
517
518 let usage = usage
519 .as_ref()
520 .map(|usage| completion::Usage {
521 input_tokens: usage.prompt_tokens as u64,
522 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
523 total_tokens: usage.total_tokens as u64,
524 cached_input_tokens: 0,
525 cache_creation_input_tokens: 0,
526 })
527 .unwrap_or_default();
528
529 let message = message::Message::try_from(choice.message.clone())?;
531
532 let content = match message {
533 Message::Assistant { content, .. } => {
534 if content.is_empty() {
535 return Err(CompletionError::ResponseError(
536 "Response contained empty content".to_owned(),
537 ));
538 }
539
540 for c in content.iter() {
542 if !matches!(c, AssistantContent::Text(_)) {
543 tracing::warn!(target: "rig",
544 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
545 );
546 }
547 }
548
549 content.iter().map(|c| {
550 match c {
551 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
552 other => Err(CompletionError::ResponseError(
553 format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
554 ))
555 }
556 }).collect::<Result<Vec<_>, _>>()?
557 }
558 Message::User { .. } => {
559 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
560 return Err(CompletionError::ResponseError(
561 "Received user message in response where assistant message was expected".to_owned()
562 ));
563 }
564 Message::System { .. } => {
565 tracing::warn!(target: "rig", "Received system message in response where assistant message was expected");
566 return Err(CompletionError::ResponseError(
567 "Received system message in response where assistant message was expected".to_owned(),
568 ));
569 }
570 };
571
572 (content, usage)
573 }
574 CompletionResponse::Simple(text) => (
575 vec![completion::AssistantContent::text(text)],
576 completion::Usage::new(),
577 ),
578 };
579
580 let choice = OneOrMany::many(content).map_err(|_| {
581 CompletionError::ResponseError(
582 "Response contained no message or tool call (empty)".to_owned(),
583 )
584 })?;
585
586 Ok(completion::CompletionResponse {
587 choice,
588 usage,
589 raw_response: response,
590 message_id: None,
591 })
592 }
593}
594
/// Token usage as reported by Mira (no separate completion count).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
600
601impl std::fmt::Display for Usage {
602 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
603 write!(
604 f,
605 "Prompt tokens: {} Total tokens: {}",
606 self.prompt_tokens, self.total_tokens
607 )
608 }
609}
610
611impl From<Message> for serde_json::Value {
612 fn from(msg: Message) -> Self {
613 match msg {
614 Message::System { content } => serde_json::json!({
615 "role": "system",
616 "content": content
617 }),
618 Message::User { content } => {
619 let text = content
620 .iter()
621 .map(|c| match c {
622 UserContent::Text(text) => &text.text,
623 _ => "",
624 })
625 .collect::<Vec<_>>()
626 .join("\n");
627 serde_json::json!({
628 "role": "user",
629 "content": text
630 })
631 }
632 Message::Assistant { content, .. } => {
633 let text = content
634 .iter()
635 .map(|c| match c {
636 AssistantContent::Text(text) => &text.text,
637 _ => "",
638 })
639 .collect::<Vec<_>>()
640 .join("\n");
641 serde_json::json!({
642 "role": "assistant",
643 "content": text
644 })
645 }
646 }
647 }
648}
649
650impl TryFrom<serde_json::Value> for Message {
651 type Error = CompletionError;
652
653 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
654 let role = value["role"].as_str().ok_or_else(|| {
655 CompletionError::ResponseError("Message missing role field".to_owned())
656 })?;
657
658 let content = match value.get("content") {
660 Some(content) => match content {
661 serde_json::Value::String(s) => s.clone(),
662 serde_json::Value::Array(arr) => arr
663 .iter()
664 .filter_map(|c| {
665 c.get("text")
666 .and_then(|t| t.as_str())
667 .map(|text| text.to_string())
668 })
669 .collect::<Vec<_>>()
670 .join("\n"),
671 _ => {
672 return Err(CompletionError::ResponseError(
673 "Message content must be string or array".to_owned(),
674 ));
675 }
676 },
677 None => {
678 return Err(CompletionError::ResponseError(
679 "Message missing content field".to_owned(),
680 ));
681 }
682 };
683
684 match role {
685 "system" => Ok(Message::System { content }),
686 "user" => Ok(Message::User {
687 content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
688 }),
689 "assistant" => Ok(Message::Assistant {
690 id: None,
691 content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
692 }),
693 _ => Err(CompletionError::ResponseError(format!(
694 "Unsupported message role: {role}"
695 ))),
696 }
697 }
698}
699
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// JSON -> `Message` conversion for both string and array content forms.
    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = json!({
            "role": "assistant",
            "content": "Hello there, how may I assist you today?"
        });

        let user_message_json = json!({
            "role": "user",
            "content": "What can you help me with?"
        });

        let assistant_message_array_json = json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Hello there, how may I assist you today?"
            }]
        });

        let assistant_message = Message::try_from(assistant_message_json).unwrap();
        let user_message = Message::try_from(user_message_json).unwrap();
        let assistant_message_array = Message::try_from(assistant_message_array_json).unwrap();

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text(message::Text {
                        text: "What can you help me with?".to_string()
                    })
                );
            }
            _ => panic!("Expected user message"),
        }

        match assistant_message_array {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }
    }

    /// Round-trip: typed message -> wire JSON -> typed message.
    #[test]
    fn test_message_conversion() {
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let mira_value: serde_json::Value = original_message.clone().into();

        let converted_message: Message = mira_value.try_into().unwrap();

        assert_eq!(original_message, converted_message);
    }

    /// A structured response maps its first choice into the rig response.
    #[test]
    fn test_completion_response_conversion() {
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![ChatChoice {
                message: RawMessage {
                    role: "assistant".to_string(),
                    content: "Test response".to_string(),
                },
                finish_reason: Some("stop".to_string()),
                index: Some(0),
            }],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let completion_response: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }

    /// Client construction succeeds via both `new` and the builder.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::mira::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::mira::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}
825}