1use crate::client::{
11 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
12 ProviderClient,
13};
14use crate::http_client::{self, HttpClientExt};
15use crate::message::{Document, DocumentSourceKind};
16use crate::providers::openai;
17use crate::providers::openai::send_compatible_streaming_request;
18use crate::streaming::StreamingCompletionResponse;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 message::{self, AssistantContent, Message, UserContent},
23};
24use serde::{Deserialize, Serialize};
25use std::string::FromUtf8Error;
26use thiserror::Error;
27use tracing::{self, Instrument, info_span};
28
/// Provider extension marker for the Mira API. It carries no state; it only
/// selects the Mira-specific trait implementations below.
#[derive(Debug, Default, Clone, Copy)]
pub struct MiraExt;
/// Builder marker used through [`ProviderBuilder`] to produce a [`MiraExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct MiraBuilder;

// Mira API keys are handled with the crate's `BearerAuth` scheme.
type MiraApiKey = BearerAuth;
35
/// Registers Mira as a provider, wiring it to [`MiraBuilder`].
impl Provider for MiraExt {
    type Builder = MiraBuilder;

    // Path the generic client uses for credential verification.
    // NOTE(review): presumably resolved against the configured base URL — confirm in `crate::client`.
    const VERIFY_PATH: &'static str = "/user-credits";
}
41
/// Declares which rig capabilities the Mira provider implements: only chat
/// completions. Every other capability is marked `Nothing`.
impl<H> Capabilities<H> for MiraExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    // NOTE(review): this module has an inherent `Client::list_models`, but model
    // listing is not exposed through the capability system.
    type ModelListing = Nothing;

    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
54
// Uses the trait's default `DebugExt` behaviour; Mira adds nothing provider-specific.
impl DebugExt for MiraExt {}
56
/// Builds the [`MiraExt`] extension for a Mira client and supplies the
/// provider defaults (API key type and base URL).
impl ProviderBuilder for MiraBuilder {
    type Extension<H>
        = MiraExt
    where
        H: HttpClientExt;
    type ApiKey = MiraApiKey;

    const BASE_URL: &'static str = MIRA_API_BASE_URL;

    /// Returns the stateless [`MiraExt`]. The builder argument is unused because
    /// the extension carries no configuration; this never returns an error.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(MiraExt)
    }
}
75
/// Mira API client. `H` is the HTTP backend and defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<MiraExt, H>;
/// Builder for a Mira [`Client`], keyed with a bearer-token API key.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MiraBuilder, MiraApiKey, H>;
78
/// Errors returned by Mira-specific client helpers such as `Client::list_models`.
#[derive(Debug, Error)]
pub enum MiraError {
    /// The API key was rejected.
    // NOTE(review): not constructed anywhere in the visible code — confirm it is still needed.
    #[error("Invalid API key")]
    InvalidApiKey,
    /// The API answered with a non-success HTTP status; holds the numeric status code.
    #[error("API error: {0}")]
    ApiError(u16),
    /// Building or sending the HTTP request failed.
    #[error("Request error: {0}")]
    RequestError(#[from] http_client::Error),
    /// A byte buffer could not be decoded as UTF-8.
    #[error("UTF-8 error: {0}")]
    Utf8Error(#[from] FromUtf8Error),
    /// A response body could not be parsed as the expected JSON.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
}
92
/// Error payload of the shape `{ "message": "..." }`; converted into
/// `CompletionError::ProviderError` by a `From` impl in this module.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
97
/// A single chat message as exchanged with the Mira API: a role string plus
/// plain-text content. In this module, requests send the `"user"` and
/// `"assistant"` roles, and those are also the only roles accepted when
/// converting back into a rig message.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct RawMessage {
    pub role: String,
    pub content: String,
}
103
// Default base URL for the Mira API; exposed as `MiraBuilder::BASE_URL`.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
105
106impl TryFrom<RawMessage> for message::Message {
107 type Error = CompletionError;
108
109 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
110 match raw.role.as_str() {
111 "user" => Ok(message::Message::User {
112 content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
113 }),
114 "assistant" => Ok(message::Message::Assistant {
115 id: None,
116 content: OneOrMany::one(AssistantContent::Text(message::Text {
117 text: raw.content,
118 })),
119 }),
120 _ => Err(CompletionError::ResponseError(format!(
121 "Unsupported message role: {}",
122 raw.role
123 ))),
124 }
125 }
126}
127
/// Body returned by `POST /v1/chat/completions`.
///
/// Because of `#[serde(untagged)]`, serde tries the variants in declaration
/// order: first the OpenAI-style `Structured` object, then a bare JSON
/// string (`Simple`). Keep this variant order.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum CompletionResponse {
    /// OpenAI-compatible completion object.
    Structured {
        id: String,
        object: String,
        created: u64,
        model: String,
        choices: Vec<ChatChoice>,
        #[serde(skip_serializing_if = "Option::is_none")]
        usage: Option<Usage>,
    },
    /// Plain-text response with no metadata.
    Simple(String),
}
142
/// One choice in a structured completion response. Only `message` is
/// required; the other fields default to `None` when absent.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatChoice {
    pub message: RawMessage,
    #[serde(default)]
    pub finish_reason: Option<String>,
    #[serde(default)]
    pub index: Option<usize>,
}
151
/// Body of `GET /v1/models`: a `data` array of model descriptors.
#[derive(Debug, Deserialize, Serialize)]
struct ModelsResponse {
    data: Vec<ModelInfo>,
}

/// A model descriptor. Only the `id` is read; any other fields are ignored
/// during deserialization.
#[derive(Debug, Deserialize, Serialize)]
struct ModelInfo {
    id: String,
}
161
162impl<T> Client<T>
163where
164 T: HttpClientExt + 'static,
165{
166 pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
168 let req = self.get("/v1/models").and_then(|req| {
169 req.body(http_client::NoBody)
170 .map_err(http_client::Error::Protocol)
171 })?;
172
173 let response = self.send(req).await?;
174
175 let status = response.status();
176
177 if !status.is_success() {
178 let error_text = http_client::text(response).await.unwrap_or_default();
180 tracing::error!("Error response: {}", error_text);
181 return Err(MiraError::ApiError(status.as_u16()));
182 }
183
184 let response_text = http_client::text(response).await?;
185
186 let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
187 tracing::error!("Failed to parse response: {}", e);
188 MiraError::JsonError(e)
189 })?;
190
191 Ok(models.data.into_iter().map(|model| model.id).collect())
192 }
193}
194
195impl ProviderClient for Client {
196 type Input = String;
197
198 fn from_env() -> Self {
201 let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
202 Self::new(&api_key).unwrap()
203 }
204
205 fn from_val(input: Self::Input) -> Self {
206 Self::new(&input).unwrap()
207 }
208}
209
/// JSON body sent to `POST /v1/chat/completions`.
///
/// `temperature` and `max_tokens` are omitted from the JSON when `None`.
/// `stream` is `false` for plain completions and is switched to `true` by the
/// streaming path.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct MiraCompletionRequest {
    model: String,
    pub messages: Vec<RawMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    pub stream: bool,
}
220
221impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest {
222 type Error = CompletionError;
223
224 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
225 if req.output_schema.is_some() {
226 tracing::warn!("Structured outputs currently not supported for Mira");
227 }
228 let model = req.model.clone().unwrap_or_else(|| model.to_string());
229 let mut messages = Vec::new();
230
231 if let Some(content) = &req.preamble {
232 messages.push(RawMessage {
233 role: "user".to_string(),
234 content: content.to_string(),
235 });
236 }
237
238 if let Some(Message::User { content }) = req.normalized_documents() {
239 let text = content
240 .into_iter()
241 .filter_map(|doc| match doc {
242 UserContent::Document(Document {
243 data: DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data),
244 ..
245 }) => Some(data),
246 UserContent::Text(text) => Some(text.text),
247
248 _ => None,
250 })
251 .collect::<Vec<_>>()
252 .join("\n");
253
254 messages.push(RawMessage {
255 role: "user".to_string(),
256 content: text,
257 });
258 }
259
260 for msg in req.chat_history {
261 let (role, content) = match msg {
262 Message::User { content } => {
263 let text = content
264 .iter()
265 .map(|c| match c {
266 UserContent::Text(text) => &text.text,
267 _ => "",
268 })
269 .collect::<Vec<_>>()
270 .join("\n");
271 ("user", text)
272 }
273 Message::Assistant { content, .. } => {
274 let text = content
275 .iter()
276 .map(|c| match c {
277 AssistantContent::Text(text) => &text.text,
278 _ => "",
279 })
280 .collect::<Vec<_>>()
281 .join("\n");
282 ("assistant", text)
283 }
284 };
285 messages.push(RawMessage {
286 role: role.to_string(),
287 content,
288 });
289 }
290
291 Ok(Self {
292 model: model.to_string(),
293 messages,
294 temperature: req.temperature,
295 max_tokens: req.max_tokens,
296 stream: false,
297 })
298 }
299}
300
/// A Mira chat model bound to a [`Client`]. `T` is the HTTP backend.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Model identifier sent with each request unless the request sets its own `model`.
    pub model: String,
}
307
308impl<T> CompletionModel<T> {
309 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
310 Self {
311 client,
312 model: model.into(),
313 }
314 }
315}
316
317impl<T> completion::CompletionModel for CompletionModel<T>
318where
319 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
320{
321 type Response = CompletionResponse;
322 type StreamingResponse = openai::StreamingCompletionResponse;
323
324 type Client = Client<T>;
325
326 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
327 Self::new(client.clone(), model)
328 }
329
330 async fn completion(
331 &self,
332 completion_request: CompletionRequest,
333 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
334 let span = if tracing::Span::current().is_disabled() {
335 info_span!(
336 target: "rig::completions",
337 "chat",
338 gen_ai.operation.name = "chat",
339 gen_ai.provider.name = "mira",
340 gen_ai.request.model = self.model,
341 gen_ai.system_instructions = tracing::field::Empty,
342 gen_ai.response.id = tracing::field::Empty,
343 gen_ai.response.model = tracing::field::Empty,
344 gen_ai.usage.output_tokens = tracing::field::Empty,
345 gen_ai.usage.input_tokens = tracing::field::Empty,
346 )
347 } else {
348 tracing::Span::current()
349 };
350
351 span.record("gen_ai.system_instructions", &completion_request.preamble);
352
353 if !completion_request.tools.is_empty() {
354 tracing::warn!(target: "rig::completions",
355 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
356 len = completion_request.tools.len()
357 );
358 }
359
360 if completion_request.tool_choice.is_some() {
361 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
362 }
363
364 if completion_request.additional_params.is_some() {
365 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
366 }
367
368 let request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
369
370 if tracing::enabled!(tracing::Level::TRACE) {
371 tracing::trace!(target: "rig::completions",
372 "Mira completion request: {}",
373 serde_json::to_string_pretty(&request)?
374 );
375 }
376
377 let body = serde_json::to_vec(&request)?;
378
379 let req = self
380 .client
381 .post("/v1/chat/completions")?
382 .body(body)
383 .map_err(http_client::Error::from)?;
384
385 let async_block = async move {
386 let response = self
387 .client
388 .send::<_, bytes::Bytes>(req)
389 .await
390 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
391
392 let status = response.status();
393 let response_body = response.into_body().into_future().await?.to_vec();
394
395 if !status.is_success() {
396 let status = status.as_u16();
397 let error_text = String::from_utf8_lossy(&response_body).to_string();
398 return Err(CompletionError::ProviderError(format!(
399 "API error: {status} - {error_text}"
400 )));
401 }
402
403 let response: CompletionResponse = serde_json::from_slice(&response_body)?;
404
405 if tracing::enabled!(tracing::Level::TRACE) {
406 tracing::trace!(target: "rig::completions",
407 "Mira completion response: {}",
408 serde_json::to_string_pretty(&response)?
409 );
410 }
411
412 if let CompletionResponse::Structured {
413 id, model, usage, ..
414 } = &response
415 {
416 let span = tracing::Span::current();
417 span.record("gen_ai.response.model_name", model);
418 span.record("gen_ai.response.id", id);
419 if let Some(usage) = usage {
420 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
421 span.record(
422 "gen_ai.usage.output_tokens",
423 usage.total_tokens - usage.prompt_tokens,
424 );
425 }
426 }
427
428 response.try_into()
429 };
430
431 async_block.instrument(span).await
432 }
433
434 async fn stream(
435 &self,
436 completion_request: CompletionRequest,
437 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
438 let span = if tracing::Span::current().is_disabled() {
439 info_span!(
440 target: "rig::completions",
441 "chat_streaming",
442 gen_ai.operation.name = "chat_streaming",
443 gen_ai.provider.name = "mira",
444 gen_ai.request.model = self.model,
445 gen_ai.system_instructions = tracing::field::Empty,
446 gen_ai.response.id = tracing::field::Empty,
447 gen_ai.response.model = tracing::field::Empty,
448 gen_ai.usage.output_tokens = tracing::field::Empty,
449 gen_ai.usage.input_tokens = tracing::field::Empty,
450 )
451 } else {
452 tracing::Span::current()
453 };
454
455 span.record("gen_ai.system_instructions", &completion_request.preamble);
456
457 if !completion_request.tools.is_empty() {
458 tracing::warn!(target: "rig::completions",
459 "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
460 len = completion_request.tools.len()
461 );
462 }
463
464 if completion_request.tool_choice.is_some() {
465 tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
466 }
467
468 if completion_request.additional_params.is_some() {
469 tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
470 }
471 let mut request =
472 MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
473 request.stream = true;
474
475 if tracing::enabled!(tracing::Level::TRACE) {
476 tracing::trace!(target: "rig::completions",
477 "Mira completion request: {}",
478 serde_json::to_string_pretty(&request)?
479 );
480 }
481
482 let body = serde_json::to_vec(&request)?;
483
484 let req = self
485 .client
486 .post("/v1/chat/completions")?
487 .body(body)
488 .map_err(http_client::Error::from)?;
489
490 send_compatible_streaming_request(self.client.clone(), req)
491 .instrument(span)
492 .await
493 }
494}
495
496impl From<ApiErrorResponse> for CompletionError {
497 fn from(err: ApiErrorResponse) -> Self {
498 CompletionError::ProviderError(err.message)
499 }
500}
501
502impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
503 type Error = CompletionError;
504
505 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
506 let (content, usage) = match &response {
507 CompletionResponse::Structured { choices, usage, .. } => {
508 let choice = choices.first().ok_or_else(|| {
509 CompletionError::ResponseError("Response contained no choices".to_owned())
510 })?;
511
512 let usage = usage
513 .as_ref()
514 .map(|usage| completion::Usage {
515 input_tokens: usage.prompt_tokens as u64,
516 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
517 total_tokens: usage.total_tokens as u64,
518 cached_input_tokens: 0,
519 })
520 .unwrap_or_default();
521
522 let message = message::Message::try_from(choice.message.clone())?;
524
525 let content = match message {
526 Message::Assistant { content, .. } => {
527 if content.is_empty() {
528 return Err(CompletionError::ResponseError(
529 "Response contained empty content".to_owned(),
530 ));
531 }
532
533 for c in content.iter() {
535 if !matches!(c, AssistantContent::Text(_)) {
536 tracing::warn!(target: "rig",
537 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
538 );
539 }
540 }
541
542 content.iter().map(|c| {
543 match c {
544 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
545 other => Err(CompletionError::ResponseError(
546 format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
547 ))
548 }
549 }).collect::<Result<Vec<_>, _>>()?
550 }
551 Message::User { .. } => {
552 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
553 return Err(CompletionError::ResponseError(
554 "Received user message in response where assistant message was expected".to_owned()
555 ));
556 }
557 };
558
559 (content, usage)
560 }
561 CompletionResponse::Simple(text) => (
562 vec![completion::AssistantContent::text(text)],
563 completion::Usage::new(),
564 ),
565 };
566
567 let choice = OneOrMany::many(content).map_err(|_| {
568 CompletionError::ResponseError(
569 "Response contained no message or tool call (empty)".to_owned(),
570 )
571 })?;
572
573 Ok(completion::CompletionResponse {
574 choice,
575 usage,
576 raw_response: response,
577 message_id: None,
578 })
579 }
580}
581
/// Token counts reported by Mira.
///
/// Only prompt and total counts are provided. Output tokens are derived from
/// them elsewhere in this module as `total - prompt`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
587
588impl std::fmt::Display for Usage {
589 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
590 write!(
591 f,
592 "Prompt tokens: {} Total tokens: {}",
593 self.prompt_tokens, self.total_tokens
594 )
595 }
596}
597
598impl From<Message> for serde_json::Value {
599 fn from(msg: Message) -> Self {
600 match msg {
601 Message::User { content } => {
602 let text = content
603 .iter()
604 .map(|c| match c {
605 UserContent::Text(text) => &text.text,
606 _ => "",
607 })
608 .collect::<Vec<_>>()
609 .join("\n");
610 serde_json::json!({
611 "role": "user",
612 "content": text
613 })
614 }
615 Message::Assistant { content, .. } => {
616 let text = content
617 .iter()
618 .map(|c| match c {
619 AssistantContent::Text(text) => &text.text,
620 _ => "",
621 })
622 .collect::<Vec<_>>()
623 .join("\n");
624 serde_json::json!({
625 "role": "assistant",
626 "content": text
627 })
628 }
629 }
630 }
631}
632
633impl TryFrom<serde_json::Value> for Message {
634 type Error = CompletionError;
635
636 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
637 let role = value["role"].as_str().ok_or_else(|| {
638 CompletionError::ResponseError("Message missing role field".to_owned())
639 })?;
640
641 let content = match value.get("content") {
643 Some(content) => match content {
644 serde_json::Value::String(s) => s.clone(),
645 serde_json::Value::Array(arr) => arr
646 .iter()
647 .filter_map(|c| {
648 c.get("text")
649 .and_then(|t| t.as_str())
650 .map(|text| text.to_string())
651 })
652 .collect::<Vec<_>>()
653 .join("\n"),
654 _ => {
655 return Err(CompletionError::ResponseError(
656 "Message content must be string or array".to_owned(),
657 ));
658 }
659 },
660 None => {
661 return Err(CompletionError::ResponseError(
662 "Message missing content field".to_owned(),
663 ));
664 }
665 };
666
667 match role {
668 "user" => Ok(Message::User {
669 content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
670 }),
671 "assistant" => Ok(Message::Assistant {
672 id: None,
673 content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
674 }),
675 _ => Err(CompletionError::ResponseError(format!(
676 "Unsupported message role: {role}"
677 ))),
678 }
679 }
680}
681
// Unit tests for the Mira message/response conversions and client construction.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    /// JSON messages, in both the string and the array `content` forms, parse
    /// into the expected rig `Message` variants via `TryFrom<serde_json::Value>`.
    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = json!({
            "role": "assistant",
            "content": "Hello there, how may I assist you today?"
        });

        let user_message_json = json!({
            "role": "user",
            "content": "What can you help me with?"
        });

        // Array form: the `text` fields of the elements are joined into one text part.
        let assistant_message_array_json = json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Hello there, how may I assist you today?"
            }]
        });

        let assistant_message = Message::try_from(assistant_message_json).unwrap();
        let user_message = Message::try_from(user_message_json).unwrap();
        let assistant_message_array = Message::try_from(assistant_message_array_json).unwrap();

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text(message::Text {
                        text: "What can you help me with?".to_string()
                    })
                );
            }
            _ => panic!("Expected user message"),
        }

        match assistant_message_array {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    AssistantContent::Text(message::Text {
                        text: "Hello there, how may I assist you today?".to_string()
                    })
                );
            }
            _ => panic!("Expected assistant message"),
        }
    }

    /// A text-only user message survives a round trip through `serde_json::Value`.
    #[test]
    fn test_message_conversion() {
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let mira_value: serde_json::Value = original_message.clone().into();

        let converted_message: Message = mira_value.try_into().unwrap();

        assert_eq!(original_message, converted_message);
    }

    /// A structured response converts into a rig completion response whose
    /// first choice is the assistant's text.
    #[test]
    fn test_completion_response_conversion() {
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![ChatChoice {
                message: RawMessage {
                    role: "assistant".to_string(),
                    content: "Test response".to_string(),
                },
                finish_reason: Some("stop".to_string()),
                index: Some(0),
            }],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let completion_response: completion::CompletionResponse<CompletionResponse> =
            mira_response.try_into().unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
    /// Both construction paths (`Client::new` and the builder) succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::mira::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::mira::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}