1use crate::client::{ClientBuilderError, CompletionClient, ProviderClient};
11use crate::json_utils::merge;
12use crate::providers::openai;
13use crate::providers::openai::send_compatible_streaming_request;
14use crate::streaming::StreamingCompletionResponse;
15use crate::{
16 OneOrMany,
17 completion::{self, CompletionError, CompletionRequest},
18 impl_conversion_traits,
19 message::{self, AssistantContent, Message, UserContent},
20};
21use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
22use serde::{Deserialize, Serialize};
23use serde_json::{Value, json};
24use std::string::FromUtf8Error;
25use thiserror::Error;
26use tracing;
27
28#[derive(Debug, Error)]
29pub enum MiraError {
30 #[error("Invalid API key")]
31 InvalidApiKey,
32 #[error("API error: {0}")]
33 ApiError(u16),
34 #[error("Request error: {0}")]
35 RequestError(#[from] reqwest::Error),
36 #[error("UTF-8 error: {0}")]
37 Utf8Error(#[from] FromUtf8Error),
38 #[error("JSON error: {0}")]
39 JsonError(#[from] serde_json::Error),
40}
41
42#[derive(Debug, Deserialize)]
43struct ApiErrorResponse {
44 message: String,
45}
46
47#[derive(Debug, Deserialize, Clone, Serialize)]
48pub struct RawMessage {
49 pub role: String,
50 pub content: String,
51}
52
/// Default base URL of the Mira API; endpoint paths such as `/v1/models` are appended to it.
const MIRA_API_BASE_URL: &str = "https://api.mira.network";
54
55impl TryFrom<RawMessage> for message::Message {
56 type Error = CompletionError;
57
58 fn try_from(raw: RawMessage) -> Result<Self, Self::Error> {
59 match raw.role.as_str() {
60 "user" => Ok(message::Message::User {
61 content: OneOrMany::one(UserContent::Text(message::Text { text: raw.content })),
62 }),
63 "assistant" => Ok(message::Message::Assistant {
64 id: None,
65 content: OneOrMany::one(AssistantContent::Text(message::Text {
66 text: raw.content,
67 })),
68 }),
69 _ => Err(CompletionError::ResponseError(format!(
70 "Unsupported message role: {}",
71 raw.role
72 ))),
73 }
74 }
75}
76
77#[derive(Debug, Deserialize, Serialize)]
78#[serde(untagged)]
79pub enum CompletionResponse {
80 Structured {
81 id: String,
82 object: String,
83 created: u64,
84 model: String,
85 choices: Vec<ChatChoice>,
86 #[serde(skip_serializing_if = "Option::is_none")]
87 usage: Option<Usage>,
88 },
89 Simple(String),
90}
91
92#[derive(Debug, Deserialize, Serialize)]
93pub struct ChatChoice {
94 pub message: RawMessage,
95 #[serde(default)]
96 pub finish_reason: Option<String>,
97 #[serde(default)]
98 pub index: Option<usize>,
99}
100
101#[derive(Debug, Deserialize, Serialize)]
102struct ModelsResponse {
103 data: Vec<ModelInfo>,
104}
105
106#[derive(Debug, Deserialize, Serialize)]
107struct ModelInfo {
108 id: String,
109}
110
111pub struct ClientBuilder<'a> {
112 api_key: &'a str,
113 base_url: &'a str,
114 http_client: Option<reqwest::Client>,
115}
116
117impl<'a> ClientBuilder<'a> {
118 pub fn new(api_key: &'a str) -> Self {
119 Self {
120 api_key,
121 base_url: MIRA_API_BASE_URL,
122 http_client: None,
123 }
124 }
125
126 pub fn base_url(mut self, base_url: &'a str) -> Self {
127 self.base_url = base_url;
128 self
129 }
130
131 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
132 self.http_client = Some(client);
133 self
134 }
135
136 pub fn build(self) -> Result<Client, ClientBuilderError> {
137 let mut headers = HeaderMap::new();
138 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
139 headers.insert(
140 reqwest::header::ACCEPT,
141 HeaderValue::from_static("application/json"),
142 );
143 headers.insert(
144 reqwest::header::USER_AGENT,
145 HeaderValue::from_static("rig-client/1.0"),
146 );
147 let http_client = if let Some(http_client) = self.http_client {
148 http_client
149 } else {
150 reqwest::Client::builder().build()?
151 };
152
153 Ok(Client {
154 base_url: self.base_url.to_string(),
155 http_client,
156 api_key: self.api_key.to_string(),
157 headers,
158 })
159 }
160}
161
162#[derive(Clone)]
163pub struct Client {
165 base_url: String,
166 http_client: reqwest::Client,
167 api_key: String,
168 headers: HeaderMap,
169}
170
171impl std::fmt::Debug for Client {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 f.debug_struct("Client")
174 .field("base_url", &self.base_url)
175 .field("http_client", &self.http_client)
176 .field("api_key", &"<REDACTED>")
177 .field("headers", &self.headers)
178 .finish()
179 }
180}
181
182impl Client {
183 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
194 ClientBuilder::new(api_key)
195 }
196
197 pub fn new(api_key: &str) -> Self {
202 Self::builder(api_key)
203 .build()
204 .expect("Mira client should build")
205 }
206
207 pub async fn list_models(&self) -> Result<Vec<String>, MiraError> {
209 let url = format!("{}/v1/models", self.base_url);
210
211 let response = self
212 .http_client
213 .get(&url)
214 .bearer_auth(&self.api_key)
215 .headers(self.headers.clone())
216 .send()
217 .await?;
218
219 let status = response.status();
220
221 if !status.is_success() {
222 let _error_text = response.text().await.unwrap_or_default();
224 tracing::error!("Error response: {}", _error_text);
225 return Err(MiraError::ApiError(status.as_u16()));
226 }
227
228 let response_text = response.text().await?;
229
230 let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| {
231 tracing::error!("Failed to parse response: {}", e);
232 MiraError::JsonError(e)
233 })?;
234
235 Ok(models.data.into_iter().map(|model| model.id).collect())
236 }
237}
238
239impl ProviderClient for Client {
240 fn from_env() -> Self {
243 let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set");
244 Self::new(&api_key)
245 }
246
247 fn from_val(input: crate::client::ProviderValue) -> Self {
248 let crate::client::ProviderValue::Simple(api_key) = input else {
249 panic!("Incorrect provider value type")
250 };
251 Self::new(&api_key)
252 }
253}
254
255impl CompletionClient for Client {
256 type CompletionModel = CompletionModel;
257 fn completion_model(&self, model: &str) -> CompletionModel {
259 CompletionModel::new(self.to_owned(), model)
260 }
261}
262
// Generates the `AsEmbeddings`, `AsTranscription`, `AsImageGeneration` and
// `AsAudioGeneration` conversion impls for `Client`.
// NOTE(review): presumably these are the "capability not supported" defaults, since this
// provider only implements completions and model listing here — confirm against the
// `impl_conversion_traits!` definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
269
270#[derive(Clone)]
271pub struct CompletionModel {
272 client: Client,
273 pub model: String,
275}
276
277impl CompletionModel {
278 pub fn new(client: Client, model: &str) -> Self {
279 Self {
280 client,
281 model: model.to_string(),
282 }
283 }
284
285 fn create_completion_request(
286 &self,
287 completion_request: CompletionRequest,
288 ) -> Result<Value, CompletionError> {
289 let mut messages = Vec::new();
290
291 if let Some(preamble) = &completion_request.preamble {
293 messages.push(serde_json::json!({
294 "role": "user",
295 "content": preamble.to_string()
296 }));
297 }
298
299 if let Some(Message::User { content }) = completion_request.normalized_documents() {
301 let text = content
302 .into_iter()
303 .filter_map(|doc| match doc {
304 UserContent::Document(doc) => Some(doc.data),
305 UserContent::Text(text) => Some(text.text),
306
307 _ => None,
309 })
310 .collect::<Vec<_>>()
311 .join("\n");
312
313 messages.push(serde_json::json!({
314 "role": "user",
315 "content": text
316 }));
317 }
318
319 for msg in completion_request.chat_history {
321 let (role, content) = match msg {
322 Message::User { content } => {
323 let text = content
324 .iter()
325 .map(|c| match c {
326 UserContent::Text(text) => &text.text,
327 _ => "",
328 })
329 .collect::<Vec<_>>()
330 .join("\n");
331 ("user", text)
332 }
333 Message::Assistant { content, .. } => {
334 let text = content
335 .iter()
336 .map(|c| match c {
337 AssistantContent::Text(text) => &text.text,
338 _ => "",
339 })
340 .collect::<Vec<_>>()
341 .join("\n");
342 ("assistant", text)
343 }
344 };
345 messages.push(serde_json::json!({
346 "role": role,
347 "content": content
348 }));
349 }
350
351 let request = serde_json::json!({
352 "model": self.model,
353 "messages": messages,
354 "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7),
355 "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100),
356 "stream": false
357 });
358
359 Ok(request)
360 }
361}
362
363impl completion::CompletionModel for CompletionModel {
364 type Response = CompletionResponse;
365 type StreamingResponse = openai::StreamingCompletionResponse;
366
367 #[cfg_attr(feature = "worker", worker::send)]
368 async fn completion(
369 &self,
370 completion_request: CompletionRequest,
371 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
372 if !completion_request.tools.is_empty() {
373 tracing::warn!(target: "rig",
374 "Tool calls are not supported by the Mira provider. {} tools will be ignored.",
375 completion_request.tools.len()
376 );
377 }
378
379 let mira_request = self.create_completion_request(completion_request)?;
380
381 let response = self
382 .client
383 .http_client
384 .post(format!("{}/v1/chat/completions", self.client.base_url))
385 .bearer_auth(&self.client.api_key)
386 .headers(self.client.headers.clone())
387 .json(&mira_request)
388 .send()
389 .await
390 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
391
392 if !response.status().is_success() {
393 let status = response.status().as_u16();
394 let error_text = response.text().await.unwrap_or_default();
395 return Err(CompletionError::ProviderError(format!(
396 "API error: {status} - {error_text}"
397 )));
398 }
399
400 let response: CompletionResponse = response
401 .json()
402 .await
403 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
404
405 response.try_into()
406 }
407
408 #[cfg_attr(feature = "worker", worker::send)]
409 async fn stream(
410 &self,
411 completion_request: CompletionRequest,
412 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
413 let mut request = self.create_completion_request(completion_request)?;
414
415 request = merge(request, json!({"stream": true}));
416
417 let builder = self
418 .client
419 .http_client
420 .post(format!("{}/v1/chat/completions", self.client.base_url))
421 .headers(self.client.headers.clone())
422 .json(&request);
423
424 send_compatible_streaming_request(builder).await
425 }
426}
427
428impl From<ApiErrorResponse> for CompletionError {
429 fn from(err: ApiErrorResponse) -> Self {
430 CompletionError::ProviderError(err.message)
431 }
432}
433
434impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
435 type Error = CompletionError;
436
437 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
438 let (content, usage) = match &response {
439 CompletionResponse::Structured { choices, usage, .. } => {
440 let choice = choices.first().ok_or_else(|| {
441 CompletionError::ResponseError("Response contained no choices".to_owned())
442 })?;
443
444 let usage = usage
445 .as_ref()
446 .map(|usage| completion::Usage {
447 input_tokens: usage.prompt_tokens as u64,
448 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
449 total_tokens: usage.total_tokens as u64,
450 })
451 .unwrap_or_default();
452
453 let message = message::Message::try_from(choice.message.clone())?;
455
456 let content = match message {
457 Message::Assistant { content, .. } => {
458 if content.is_empty() {
459 return Err(CompletionError::ResponseError(
460 "Response contained empty content".to_owned(),
461 ));
462 }
463
464 for c in content.iter() {
466 if !matches!(c, AssistantContent::Text(_)) {
467 tracing::warn!(target: "rig",
468 "Unsupported content type encountered: {:?}. The Mira provider currently only supports text content", c
469 );
470 }
471 }
472
473 content.iter().map(|c| {
474 match c {
475 AssistantContent::Text(text) => Ok(completion::AssistantContent::text(&text.text)),
476 other => Err(CompletionError::ResponseError(
477 format!("Unsupported content type: {other:?}. The Mira provider currently only supports text content")
478 ))
479 }
480 }).collect::<Result<Vec<_>, _>>()?
481 }
482 Message::User { .. } => {
483 tracing::warn!(target: "rig", "Received user message in response where assistant message was expected");
484 return Err(CompletionError::ResponseError(
485 "Received user message in response where assistant message was expected".to_owned()
486 ));
487 }
488 };
489
490 (content, usage)
491 }
492 CompletionResponse::Simple(text) => (
493 vec![completion::AssistantContent::text(text)],
494 completion::Usage::new(),
495 ),
496 };
497
498 let choice = OneOrMany::many(content).map_err(|_| {
499 CompletionError::ResponseError(
500 "Response contained no message or tool call (empty)".to_owned(),
501 )
502 })?;
503
504 Ok(completion::CompletionResponse {
505 choice,
506 usage,
507 raw_response: response,
508 })
509 }
510}
511
512#[derive(Clone, Debug, Deserialize, Serialize)]
513pub struct Usage {
514 pub prompt_tokens: usize,
515 pub total_tokens: usize,
516}
517
518impl std::fmt::Display for Usage {
519 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
520 write!(
521 f,
522 "Prompt tokens: {} Total tokens: {}",
523 self.prompt_tokens, self.total_tokens
524 )
525 }
526}
527
528impl From<Message> for serde_json::Value {
529 fn from(msg: Message) -> Self {
530 match msg {
531 Message::User { content } => {
532 let text = content
533 .iter()
534 .map(|c| match c {
535 UserContent::Text(text) => &text.text,
536 _ => "",
537 })
538 .collect::<Vec<_>>()
539 .join("\n");
540 serde_json::json!({
541 "role": "user",
542 "content": text
543 })
544 }
545 Message::Assistant { content, .. } => {
546 let text = content
547 .iter()
548 .map(|c| match c {
549 AssistantContent::Text(text) => &text.text,
550 _ => "",
551 })
552 .collect::<Vec<_>>()
553 .join("\n");
554 serde_json::json!({
555 "role": "assistant",
556 "content": text
557 })
558 }
559 }
560 }
561}
562
563impl TryFrom<serde_json::Value> for Message {
564 type Error = CompletionError;
565
566 fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
567 let role = value["role"].as_str().ok_or_else(|| {
568 CompletionError::ResponseError("Message missing role field".to_owned())
569 })?;
570
571 let content = match value.get("content") {
573 Some(content) => match content {
574 serde_json::Value::String(s) => s.clone(),
575 serde_json::Value::Array(arr) => arr
576 .iter()
577 .filter_map(|c| {
578 c.get("text")
579 .and_then(|t| t.as_str())
580 .map(|text| text.to_string())
581 })
582 .collect::<Vec<_>>()
583 .join("\n"),
584 _ => {
585 return Err(CompletionError::ResponseError(
586 "Message content must be string or array".to_owned(),
587 ));
588 }
589 },
590 None => {
591 return Err(CompletionError::ResponseError(
592 "Message missing content field".to_owned(),
593 ));
594 }
595 };
596
597 match role {
598 "user" => Ok(Message::User {
599 content: OneOrMany::one(UserContent::Text(message::Text { text: content })),
600 }),
601 "assistant" => Ok(Message::Assistant {
602 id: None,
603 content: OneOrMany::one(AssistantContent::Text(message::Text { text: content })),
604 }),
605 _ => Err(CompletionError::ResponseError(format!(
606 "Unsupported message role: {role}"
607 ))),
608 }
609 }
610}
611
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserContent;
    use serde_json::json;

    const GREETING: &str = "Hello there, how may I assist you today?";

    /// Returns the first content item of an assistant message, panicking for any other role.
    fn first_assistant_item(msg: Message) -> AssistantContent {
        let Message::Assistant { content, .. } = msg else {
            panic!("Expected assistant message");
        };
        content.first()
    }

    #[test]
    fn test_deserialize_message() {
        let parse = |value: serde_json::Value| Message::try_from(value).unwrap();

        let assistant_message = parse(json!({
            "role": "assistant",
            "content": GREETING
        }));
        let user_message = parse(json!({
            "role": "user",
            "content": "What can you help me with?"
        }));
        let assistant_message_array = parse(json!({
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": GREETING
            }]
        }));

        let expected_reply = AssistantContent::Text(message::Text {
            text: GREETING.to_string(),
        });

        assert_eq!(first_assistant_item(assistant_message), expected_reply);

        let Message::User { content } = user_message else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text(message::Text {
                text: "What can you help me with?".to_string()
            })
        );

        assert_eq!(first_assistant_item(assistant_message_array), expected_reply);
    }

    #[test]
    fn test_message_conversion() {
        let original_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        // Round-trip through the JSON representation.
        let mira_value = serde_json::Value::from(original_message.clone());
        let converted_message = Message::try_from(mira_value).unwrap();

        assert_eq!(original_message, converted_message);
    }

    #[test]
    fn test_completion_response_conversion() {
        let choice = ChatChoice {
            message: RawMessage {
                role: "assistant".to_string(),
                content: "Test response".to_string(),
            },
            finish_reason: Some("stop".to_string()),
            index: Some(0),
        };
        let mira_response = CompletionResponse::Structured {
            id: "resp_123".to_string(),
            object: "chat.completion".to_string(),
            created: 1234567890,
            model: "deepseek-r1".to_string(),
            choices: vec![choice],
            usage: Some(Usage {
                prompt_tokens: 10,
                total_tokens: 20,
            }),
        };

        let completion_response =
            completion::CompletionResponse::<CompletionResponse>::try_from(mira_response)
                .unwrap();

        assert_eq!(
            completion_response.choice.first(),
            completion::AssistantContent::text("Test response")
        );
    }
}