// kalosm_llama/language_model.rs
//
// Implementations of the `kalosm_language_model` traits for [`Llama`]:
// model building, plain text completion, and parser-constrained
// (structured) generation.

use kalosm_language_model::{
    CreateDefaultChatConstraintsForType, CreateDefaultCompletionConstraintsForType,
    CreateTextCompletionSession, GenerationParameters, ModelBuilder, StructuredTextCompletionModel,
    TextCompletionModel,
};
use kalosm_model_types::ModelLoadingProgress;
use kalosm_sample::{ArcParser, CreateParserState, Parse, Parser, ParserExt};
use llm_samplers::types::Sampler;
use std::any::Any;
use std::future::Future;

use crate::model::LlamaModelError;
use crate::structured::generate_structured;
pub use crate::Llama;
use crate::LlamaBuilder;
use crate::{
    InferenceSettings, LlamaSession, LlamaSourceError, StructuredGenerationTask, Task,
    UnstructuredGenerationTask,
};

impl ModelBuilder for LlamaBuilder {
    type Model = Llama;
    type Error = LlamaSourceError;

    async fn start_with_loading_handler(
        self,
        handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
    ) -> Result<Self::Model, Self::Error> {
        self.build_with_loading_handler(handler).await
    }

    fn requires_download(&self) -> bool {
        let cache = &self.source.cache;
        // A download is required if the model weights are not cached, or if
        // no cached tokenizer is available.
        !cache.exists(&self.source.model)
            || self
                .source
                .tokenizer
                .as_ref()
                .filter(|t| cache.exists(t))
                .is_none()
    }
}
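
// A minimal usage sketch for the impl above, assuming a Tokio runtime, the
// crate's `Llama::builder()` entry point, and that `ModelLoadingProgress`
// implements `Debug`. The handler is called repeatedly while the weights are
// downloaded and loaded.
#[cfg(test)]
#[allow(dead_code)]
async fn example_build_with_progress() -> Result<Llama, LlamaSourceError> {
    Llama::builder()
        .start_with_loading_handler(|progress: ModelLoadingProgress| {
            // Invoked on each progress update during download/loading.
            println!("loading: {progress:?}");
        })
        .await
}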

impl CreateTextCompletionSession for Llama {
    type Session = LlamaSession;
    type Error = LlamaModelError;

    fn new_session(&self) -> Result<Self::Session, Self::Error> {
        // Each session holds its own inference state (such as the KV cache),
        // independent of the shared model weights.
        Ok(LlamaSession::new(&self.config))
    }
}

impl<S: Sampler + 'static> TextCompletionModel<S> for Llama {
    fn stream_text_with_callback<'a>(
        &'a self,
        session: &'a mut Self::Session,
        text: &str,
        sampler: S,
        on_token: impl FnMut(String) -> Result<(), Self::Error> + Send + Sync + 'static,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'a {
        let text = text.to_string();
        async move {
            let (tx, rx) = tokio::sync::oneshot::channel();
            // If the caller passed `GenerationParameters` as the sampler,
            // honor its token limit, stop string, and seed; any other sampler
            // runs without those settings.
            let (max_tokens, stop_on, seed) =
                match (&sampler as &dyn Any).downcast_ref::<GenerationParameters>() {
                    Some(sampler) => (
                        sampler.max_length(),
                        sampler.stop_on().map(|s| s.to_string()),
                        sampler.seed(),
                    ),
                    None => (u32::MAX, None, None),
                };
            let sampler = std::sync::Arc::new(std::sync::Mutex::new(sampler));
            let on_token = Box::new(on_token);
            // Queue the request on the background inference task. Tokens are
            // streamed back through `on_token`; completion (or failure) is
            // reported over the oneshot channel.
            self.task_sender
                .send(Task::UnstructuredGeneration(UnstructuredGenerationTask {
                    settings: InferenceSettings::new(
                        text,
                        session.clone(),
                        sampler,
                        max_tokens,
                        stop_on,
                        seed,
                    ),
                    on_token,
                    finished: tx,
                }))
                .map_err(|_| LlamaModelError::ModelStopped)?;

            rx.await.map_err(|_| LlamaModelError::ModelStopped)??;

            Ok(())
        }
    }
}
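
// A minimal streaming sketch for the impl above, assuming
// `GenerationParameters` implements `Sampler` and `Default` (it is the
// sampler type the downcast above is written for). Passing it lets the
// generation pick up its max-length, stop string, and seed.
#[cfg(test)]
#[allow(dead_code)]
async fn example_stream_text(model: &Llama) -> Result<(), LlamaModelError> {
    let mut session = model.new_session()?;
    model
        .stream_text_with_callback(
            &mut session,
            "The capital of France is",
            GenerationParameters::default(),
            |token| {
                // Each token is delivered here as soon as it is sampled.
                print!("{token}");
                Ok(())
            },
        )
        .await
}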

// For any type that implements `Parse`, the default chat constraints are the
// type's own parser.
impl<T: Parse + 'static> CreateDefaultChatConstraintsForType<T> for Llama {
    type DefaultConstraints = ArcParser<T>;

    fn create_default_constraints() -> Self::DefaultConstraints {
        T::new_parser().boxed()
    }
}

// The same parser serves as the default constraints for plain completion.
impl<T: Parse + 'static> CreateDefaultCompletionConstraintsForType<T> for Llama {
    type DefaultConstraints = ArcParser<T>;

    fn create_default_constraints() -> Self::DefaultConstraints {
        T::new_parser().boxed()
    }
}
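
// Sketch: deriving `Parse` (assuming the derive macro is exported by
// `kalosm_sample` alongside the trait) is enough for a type to pick up
// default constraints through the impls above. `Character` is a hypothetical
// example type.
#[cfg(test)]
mod default_constraints_example {
    use super::*;

    #[derive(Parse)]
    #[allow(dead_code)]
    struct Character {
        name: String,
        age: u8,
    }

    #[allow(dead_code)]
    fn character_parser() -> ArcParser<Character> {
        // Routed through `T::new_parser().boxed()` by the blanket impl above.
        <Llama as CreateDefaultCompletionConstraintsForType<Character>>::create_default_constraints()
    }
}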

impl<S, Constraints> StructuredTextCompletionModel<Constraints, S> for Llama
where
    <Constraints as Parser>::Output: Send,
    Constraints: CreateParserState + Send + 'static,
    S: Sampler + 'static,
{
    fn stream_text_with_callback_and_parser<'a>(
        &'a self,
        session: &'a mut Self::Session,
        text: &str,
        sampler: S,
        parser: Constraints,
        on_token: impl FnMut(String) -> Result<(), Self::Error> + Send + Sync + 'static,
    ) -> impl Future<Output = Result<Constraints::Output, Self::Error>> + Send + 'a {
        let text = text.to_string();
        let mut session = session.clone();
        async {
            let (tx, rx) = tokio::sync::oneshot::channel();
            // Unlike the unstructured path, only the seed is extracted from
            // `GenerationParameters` here.
            let seed = match (&sampler as &dyn Any).downcast_ref::<GenerationParameters>() {
                Some(sampler) => sampler.seed(),
                None => None,
            };
            let sampler = std::sync::Arc::new(std::sync::Mutex::new(sampler));
            let on_token = Box::new(on_token);
            // Structured generation needs direct access to the model, so the
            // work is shipped to the inference task as a closure.
            self.task_sender
                .send(Task::StructuredGeneration(StructuredGenerationTask {
                    runner: Box::new(move |model| {
                        let parser_state = parser.create_parser_state();
                        let result = generate_structured(
                            text,
                            model,
                            &mut session,
                            parser,
                            parser_state,
                            sampler,
                            on_token,
                            Some(64),
                            seed,
                        );
                        // Ignore send errors: the receiver is gone if the
                        // caller dropped the future.
                        _ = tx.send(result);
                    }),
                }))
                .map_err(|_| LlamaModelError::ModelStopped)?;

            let result = rx.await.map_err(|_| LlamaModelError::ModelStopped)??;

            Ok(result)
        }
    }
}
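
// A structured-generation sketch for the impl above: the parser constrains
// sampling, so the future resolves to an already-parsed value. Assumes
// `kalosm_sample` implements `Parse` for `f64` (so `f64::new_parser()` is
// available) and that `GenerationParameters` implements `Default`.
#[cfg(test)]
#[allow(dead_code)]
async fn example_structured(model: &Llama) -> Result<f64, LlamaModelError> {
    let mut session = model.new_session()?;
    model
        .stream_text_with_callback_and_parser(
            &mut session,
            "What is 2 + 2? Answer with just a number: ",
            GenerationParameters::default(),
            f64::new_parser(),
            |token| {
                // Tokens still stream as they are generated; the parsed value
                // is returned once the parser reports the output complete.
                print!("{token}");
                Ok(())
            },
        )
        .await
}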