// kalosm_llama/language_model.rs

use kalosm_language_model::{
    CreateDefaultChatConstraintsForType, CreateDefaultCompletionConstraintsForType,
    CreateTextCompletionSession, GenerationParameters, ModelBuilder, StructuredTextCompletionModel,
    TextCompletionModel,
};
use kalosm_model_types::ModelLoadingProgress;
use kalosm_sample::{ArcParser, CreateParserState, Parse, Parser, ParserExt};
use llm_samplers::types::Sampler;
use std::any::Any;
use std::future::Future;

use crate::model::LlamaModelError;
use crate::structured::generate_structured;
pub use crate::Llama;
use crate::LlamaBuilder;
use crate::{
    InferenceSettings, LlamaSession, LlamaSourceError, StructuredGenerationTask, Task,
    UnstructuredGenerationTask,
};

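/// Wires [`LlamaBuilder`] into the generic [`ModelBuilder`] interface by
/// forwarding to [`LlamaBuilder::build_with_loading_handler`].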
impl ModelBuilder for LlamaBuilder {
    type Model = Llama;
    type Error = LlamaSourceError;

    async fn start_with_loading_handler(
        self,
        handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
    ) -> Result<Self::Model, Self::Error> {
        self.build_with_loading_handler(handler).await
    }

    fn requires_download(&self) -> bool {
        let cache = &self.source.cache;
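        // A download is required when the model weights are not cached, or
        // when no cached tokenizer is available (either none is configured or
        // the configured one has not been fetched yet).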
        !cache.exists(&self.source.model)
            || self
                .source
                .tokenizer
                .as_ref()
                .filter(|t| cache.exists(t))
                .is_none()
    }
}
43
impl CreateTextCompletionSession for Llama {
    type Session = LlamaSession;
    type Error = LlamaModelError;

    fn new_session(&self) -> Result<Self::Session, Self::Error> {
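        // Each session owns its own cached generation state; the model itself
        // stays shared and immutable.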
        Ok(LlamaSession::new(&self.config))
    }
}
52
impl<S: Sampler + 'static> TextCompletionModel<S> for Llama {
    fn stream_text_with_callback<'a>(
        &'a self,
        session: &'a mut Self::Session,
        text: &str,
        sampler: S,
        on_token: impl FnMut(String) -> Result<(), Self::Error> + Send + Sync + 'static,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'a {
        let text = text.to_string();
        async move {
            let (tx, rx) = tokio::sync::oneshot::channel();
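            // If the caller passed the standard `GenerationParameters`
            // sampler, honor its max length, stop string, and seed; any other
            // sampler runs without those limits.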
            let (max_tokens, stop_on, seed) =
                match (&sampler as &dyn Any).downcast_ref::<GenerationParameters>() {
                    Some(sampler) => (
                        sampler.max_length(),
                        sampler.stop_on().map(|s| s.to_string()),
                        sampler.seed(),
                    ),
                    None => (u32::MAX, None, None),
                };
            let sampler = std::sync::Arc::new(std::sync::Mutex::new(sampler));
            let on_token = Box::new(on_token);
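            // Queue the request on the background inference task; a send
            // error means that task has already shut down.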
            self.task_sender
                .send(Task::UnstructuredGeneration(UnstructuredGenerationTask {
                    settings: InferenceSettings::new(
                        text,
                        session.clone(),
                        sampler,
                        max_tokens,
                        stop_on,
                        seed,
                    ),
                    on_token,
                    finished: tx,
                }))
                .map_err(|_| LlamaModelError::ModelStopped)?;

            rx.await.map_err(|_| LlamaModelError::ModelStopped)??;

            Ok(())
        }
    }
}
96
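/// Any type that implements [`Parse`] can use its derived parser as the
/// default constraint when generating a `T` in a chat context.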
impl<T: Parse + 'static> CreateDefaultChatConstraintsForType<T> for Llama {
    type DefaultConstraints = ArcParser<T>;

    fn create_default_constraints() -> Self::DefaultConstraints {
        T::new_parser().boxed()
    }
}
104
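/// The same derived parser serves as the default constraint for plain text
/// completion.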
impl<T: Parse + 'static> CreateDefaultCompletionConstraintsForType<T> for Llama {
    type DefaultConstraints = ArcParser<T>;

    fn create_default_constraints() -> Self::DefaultConstraints {
        T::new_parser().boxed()
    }
}
112
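/// Structured completion runs generation under a [`Parser`] constraint and
/// returns the parsed value once the parser completes.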
impl<S, Constraints> StructuredTextCompletionModel<Constraints, S> for Llama
where
    <Constraints as Parser>::Output: Send,
    Constraints: CreateParserState + Send + 'static,
    S: Sampler + 'static,
{
    fn stream_text_with_callback_and_parser<'a>(
        &'a self,
        session: &'a mut Self::Session,
        text: &str,
        sampler: S,
        parser: Constraints,
        on_token: impl FnMut(String) -> Result<(), Self::Error> + Send + Sync + 'static,
    ) -> impl Future<Output = Result<Constraints::Output, Self::Error>> + Send + 'a {
        let text = text.to_string();
        let mut session = session.clone();
        async {
            let (tx, rx) = tokio::sync::oneshot::channel();
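            // Unlike the unstructured path, only the seed is read from
            // `GenerationParameters`; the parser decides when generation is
            // complete, so max length and stop strings do not apply here.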
            let seed = match (&sampler as &dyn Any).downcast_ref::<GenerationParameters>() {
                Some(sampler) => sampler.seed(),
                None => None,
            };
            let sampler = std::sync::Arc::new(std::sync::Mutex::new(sampler));
            let on_token = Box::new(on_token);
            self.task_sender
                .send(Task::StructuredGeneration(StructuredGenerationTask {
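                    // The runner closure executes on the model task, where it
                    // gets exclusive access to the loaded weights.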
                    runner: Box::new(move |model| {
                        let parser_state = parser.create_parser_state();
                        let result = generate_structured(
                            text,
                            model,
                            &mut session,
                            parser,
                            parser_state,
                            sampler,
                            on_token,
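                            // Assumption: `Some(64)` below is a top-k cutoff
                            // on the logits considered while constraining.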
                            Some(64),
                            seed,
                        );
                        _ = tx.send(result);
                    }),
                }))
                .map_err(|_| LlamaModelError::ModelStopped)?;

            let result = rx.await.map_err(|_| LlamaModelError::ModelStopped)??;

            Ok(result)
        }
    }
}