// rphi 0.3.2
//
// A simple interface for Phi models.
// Documentation
use crate::model::PhiModel;
use crate::InferenceSettings;
pub use crate::Phi;
use crate::PhiBuilder;
use crate::Task;
use kalosm_common::ModelLoadingProgress;
use kalosm_language_model::*;
use kalosm_streams::text_stream::ChannelTextStream;
use std::ops::Deref;
use std::sync::Arc;
use std::sync::Mutex;
use tokenizers::Tokenizer;

/// Lets a [`PhiBuilder`] be used wherever a generic `ModelBuilder` is expected.
#[async_trait::async_trait]
impl ModelBuilder for PhiBuilder {
    /// The model type this builder produces.
    type Model = Phi;

    /// Builds the model by delegating to `build_with_loading_handler`,
    /// passing `handler` through unchanged.
    ///
    /// # Errors
    ///
    /// Returns whatever error `build_with_loading_handler` reports.
    async fn start_with_loading_handler(
        self,
        handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
    ) -> anyhow::Result<Self::Model> {
        self.build_with_loading_handler(handler).await
    }

    /// Returns `true` unless both the model source and the tokenizer source
    /// report that they are already downloaded.
    fn requires_download(&self) -> bool {
        let source = &self.source;
        // The tokenizer is only consulted when the model is already present.
        !(source.model.downloaded() && source.tokenizer.downloaded())
    }
}

/// Exposes [`Phi`] through the generic `Model` interface.
#[async_trait::async_trait]
impl Model for Phi {
    /// Stream type produced by the text-generation methods.
    type TextStream = ChannelTextStream;
    /// Model type handed to callbacks scheduled via `run_sync_raw`.
    type SyncModel = PhiModel;

    /// Returns the tokenizer obtained from `get_tokenizer`.
    fn tokenizer(&self) -> Arc<Tokenizer> {
        self.get_tokenizer()
    }

    /// Queues `f` as a `Task::RunSync` on this model's task channel.
    ///
    /// This only enqueues the callback; nothing visible here waits for it
    /// to run.
    ///
    /// # Errors
    ///
    /// Fails with "Failed to send task to Phi thread" when the task cannot
    /// be sent over the channel.
    fn run_sync_raw(
        &self,
        f: Box<
            dyn for<'a> FnOnce(
                    &'a mut Self::SyncModel,
                )
                    -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + 'a>>
                + Send,
        >,
    ) -> anyhow::Result<()> {
        let task = Task::RunSync { callback: f };
        self.task_sender
            .send(task)
            .map(|_| ())
            .map_err(|_| anyhow::anyhow!("Failed to send task to Phi thread"))
    }

    /// Starts generation for `prompt`, taking the length limit, optional stop
    /// string, and sampler from `generation_parameters`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `run`.
    async fn stream_text_inner(
        &self,
        prompt: &str,
        generation_parameters: GenerationParameters,
    ) -> anyhow::Result<Self::TextStream> {
        let sample_len = generation_parameters.max_length() as usize;
        // The stop string is copied out before the sampler is requested from
        // the parameters, matching the order the values are needed in.
        let settings = InferenceSettings::new(prompt)
            .with_sample_len(sample_len)
            .with_stop_on(generation_parameters.stop_on().map(|s| s.to_string()));
        self.run(
            settings,
            Arc::new(Mutex::new(generation_parameters.sampler())),
        )
        .map(Into::into)
    }

    /// Starts generation for `prompt` using a caller-supplied `sampler`.
    ///
    /// When `max_tokens` is `None`, a limit of 64 is used. `stop_on`, if
    /// given, is copied into the inference settings.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `run`.
    async fn stream_text_with_sampler(
        &self,
        prompt: &str,
        max_tokens: Option<u32>,
        stop_on: Option<&str>,
        sampler: Arc<Mutex<dyn llm_samplers::prelude::Sampler>>,
    ) -> anyhow::Result<Self::TextStream> {
        let sample_len = max_tokens.unwrap_or(64) as usize;
        let stop_text = stop_on.map(|s| s.to_string());
        let settings = InferenceSettings::new(prompt)
            .with_sample_len(sample_len)
            .with_stop_on(stop_text);
        self.run(settings, sampler).map(Into::into)
    }

    /// Returns a clone of the chat markers stored on this model, if any.
    fn chat_markers(&self) -> Option<ChatMarkers> {
        let markers = self.chat_markers.deref();
        markers.clone()
    }
}