chat-mistralrs 0.1.6

Local-inference provider for chat-rs, built on mistral.rs.
Documentation
use async_trait::async_trait;
use chat_core::error::ChatError;
use chat_core::traits::StreamProvider;
use chat_core::types::messages::Messages;
use chat_core::types::messages::content::{Content, RoleEnum};
use chat_core::types::messages::parts::{PartEnum, Parts};
use chat_core::types::messages::text::Text;
use chat_core::types::metadata::Metadata;
use chat_core::types::metadata::usage::Usage as CoreUsage;
use chat_core::types::options::ChatOptions;
use chat_core::types::response::{ChatResponse, StreamEvent};
use chat_core::types::tools::ToolDeclarations;
use futures::StreamExt;
use futures::stream::BoxStream;
use mistralrs::Response as MResponse;

use crate::api::types::request;
use crate::api::types::response::{map_finish_reason, usage_from_m};
use crate::client::MistralRsClient;

#[async_trait]
impl StreamProvider for MistralRsClient {
    /// Streams a chat completion from the local mistral.rs model.
    ///
    /// The mistral.rs request is built from `messages` and `options` via `request::from_core`
    /// before the stream is created. Conversion failures are therefore returned eagerly as `Err`.
    /// Everything else happens lazily inside the returned stream:
    ///
    /// * Each non-empty text delta of the first choice of a `Chunk` is yielded as
    ///   [`StreamEvent::TextChunk`] and appended to the accumulated answer.
    /// * A `Done` response ends the read loop. If no text was streamed before it, its final
    ///   message content (when non-empty) is yielded once as a `TextChunk`.
    /// * Model, internal and validation errors reported by mistral.rs are yielded as
    ///   [`ChatError::Provider`] via `?`, which ends the stream.
    /// * When the read loop finishes, a final [`StreamEvent::Done`] is yielded. It carries the
    ///   accumulated text, the mapped finish reason and the last usage reported.
    ///
    /// NOTE(review): `tool_declarations` only contributes a presence flag to the request; the
    /// declarations themselves are not forwarded here (`None` is passed) — confirm intended.
    async fn stream(
        &mut self,
        messages: &mut Messages,
        tool_declarations: Option<&dyn ToolDeclarations>,
        options: Option<&ChatOptions>,
    ) -> Result<BoxStream<'static, Result<StreamEvent, ChatError>>, ChatError> {
        let tools_present = tool_declarations.is_some();
        // Build the request up front so conversion errors surface before any streaming starts.
        let req = request::from_core(messages, options, None, tools_present).map_err(|f| f.err)?;

        // Owned handles that are moved into the `'static` stream below.
        let model = self.model.clone();
        let model_id = self.model_id.clone();

        let s = async_stream::try_stream! {
            let mut raw = model
                .stream_chat_request(req)
                .await
                .map_err(|e| ChatError::Provider(format!("mistral.rs stream_chat_request: {e}")))?;

            let mut accumulated = String::new();
            let mut finish_reason: Option<String> = None;
            let mut last_usage: Option<CoreUsage> = None;

            while let Some(item) = raw.next().await {
                match item {
                    MResponse::Chunk(chunk) => {
                        // Only the first choice is streamed to the caller.
                        if let Some(choice) = chunk.choices.into_iter().next() {
                            if let Some(piece) = choice.delta.content
                                && !piece.is_empty()
                            {
                                accumulated.push_str(&piece);
                                yield StreamEvent::TextChunk(piece);
                            }
                            if let Some(reason) = choice.finish_reason {
                                finish_reason = Some(reason);
                            }
                        }
                        if let Some(u) = chunk.usage {
                            last_usage = Some(usage_from_m(u));
                        }
                    }
                    MResponse::Done(full) => {
                        if let Some(choice) = full.choices.first() {
                            // A finish reason already seen on a chunk takes precedence.
                            if finish_reason.is_none() && !choice.finish_reason.is_empty() {
                                finish_reason = Some(choice.finish_reason.clone());
                            }
                            // Fallback when nothing was streamed: emit the final content once.
                            // Empty text is skipped, matching the chunk path above.
                            if accumulated.is_empty()
                                && let Some(text) = &choice.message.content
                                && !text.is_empty()
                            {
                                accumulated = text.clone();
                                yield StreamEvent::TextChunk(text.clone());
                            }
                        }
                        last_usage = Some(usage_from_m(full.usage));
                        break;
                    }
                    MResponse::ModelError(msg, _) | MResponse::CompletionModelError(msg, _) => {
                        Err(ChatError::Provider(format!("mistral.rs model error: {msg}")))?;
                    }
                    MResponse::InternalError(e) => {
                        Err(ChatError::Provider(format!("mistral.rs internal error: {e}")))?;
                    }
                    MResponse::ValidationError(e) => {
                        Err(ChatError::Provider(format!("mistral.rs validation error: {e}")))?;
                    }
                    // Text-completion (`CompletionChunk` / `CompletionDone`) and any other
                    // non-chat responses are not meaningful for a chat request; ignore them.
                    _ => {}
                }
            }

            let response = ChatResponse {
                metadata: Some(Metadata {
                    model_slug: Some(model_id),
                    usage: last_usage.unwrap_or_default(),
                    ..Default::default()
                }),
                content: Content {
                    role: RoleEnum::Model,
                    parts: Parts(vec![PartEnum::Text(Text::new(accumulated))]),
                    complete_reason: map_finish_reason(finish_reason.as_deref()),
                },
            };
            yield StreamEvent::Done(response);
        };

        Ok(s.boxed())
    }
}