//! `ai00_core` — crate root (`lib.rs`): model loading, runtime management,
//! and the request-serving loop.

1use std::{
2    collections::HashMap,
3    path::{Path, PathBuf},
4    sync::Arc,
5};
6
7use anyhow::{bail, Result};
8use derivative::Derivative;
9use flume::{Receiver, Sender};
10use futures::future::join_all;
11use half::f16;
12use itertools::Itertools;
13use memmap2::Mmap;
14use reload::{AdapterOption, BnfOption, Precision};
15use safetensors::SafeTensors;
16use salvo::oapi::ToSchema;
17use serde::{de::DeserializeSeed, Deserialize, Serialize};
18use tokio::{
19    fs::File,
20    io::{AsyncReadExt, BufReader},
21    sync::RwLock,
22    time::Duration,
23};
24use web_rwkv::{
25    context::{Context, ContextBuilder, ContextError, InstanceExt},
26    runtime::{
27        infer::Rnn,
28        loader::{Loader, Lora, LoraBlend, Reader},
29        model::{Bundle, ContextAutoLimits, ModelBuilder, ModelInfo, ModelVersion, Quant, State},
30        v4, v5, v6, v7, Runtime, TokioRuntime,
31    },
32    tensor::{serialization::Seed, TensorCpu, TensorError, TensorInit},
33    tokenizer::Tokenizer,
34    wgpu::{Backends, PowerPreference},
35};
36
37use crate::{run::GenerateContext, sampler::Sampler};
38
39pub mod reload;
40pub mod run;
41pub mod sampler;
42
/// Sentinel meaning "no output-length limit" for requests that do not set one.
pub const MAX_TOKENS: usize = usize::MAX;
44
/// A single event streamed back to the client while serving one request.
#[derive(Debug)]
pub enum Token {
    /// Generation has started.
    Start,
    /// A piece of decoded output text.
    Content(String),
    /// Generation finished: why it stopped, plus token usage statistics.
    Stop(FinishReason, TokenCounter),
    /// An embedding result: flattened data plus its 4-D shape.
    Embed(Vec<f32>, [usize; 4]),
    /// Scores for each candidate of a [`GenerateKind::Choose`] request.
    Choose(Vec<f32>),
    /// Terminal marker: no further tokens will be sent on this channel.
    Done,
}
54
/// Token usage statistics for one request.
///
/// The serde aliases accept the OpenAI-style field names on deserialization.
#[derive(Debug, Default, Clone, Serialize, Deserialize, ToSchema)]
pub struct TokenCounter {
    /// Number of tokens in the prompt.
    #[serde(alias = "prompt_tokens")]
    pub prompt: usize,
    /// Number of tokens generated.
    #[serde(alias = "completion_tokens")]
    pub completion: usize,
    /// Total number of tokens processed.
    #[serde(alias = "total_tokens")]
    pub total: usize,
    /// Duration of the request (measured by the producer of this counter).
    pub duration: Duration,
}
65
/// Why a generation stopped; mirrors the OpenAI API `finish_reason` field.
#[derive(Debug, Default, Clone, Copy, Serialize, ToSchema)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum FinishReason {
    /// API returned complete model output.
    Stop,
    /// Incomplete model output due to max_tokens parameter or token limit.
    Length,
    /// Omitted content due to a flag from our content filters.
    ContentFilter,
    /// API response still in progress or incomplete.
    #[default]
    #[serde(untagged)]
    Null,
}
81
/// Requests handled by the serving loop (see [`serve`] and [`process`]).
#[derive(Debug, Clone)]
pub enum ThreadRequest {
    /// Acquire a list of current available adapters.
    Adapter(Sender<AdapterList>),
    /// Get the current runtime info.
    Info(Sender<RuntimeInfo>),
    /// Request the runtime to complement a prompt.
    Generate {
        request: Box<GenerateRequest>,
        tokenizer: Arc<Tokenizer>,
        /// Channel on which [`Token`] events are streamed back.
        sender: Sender<Token>,
    },
    /// Reload the runtime with custom config.
    Reload {
        request: Box<ReloadRequest>,
        /// Optional notification channel; receives `true` on success.
        sender: Option<Sender<bool>>,
    },
    /// Unload the runtime.
    Unload,
    /// Save the current model with config.
    Save {
        request: SaveRequest,
        /// Receives `true` on success, `false` on failure.
        sender: Sender<bool>,
    },
}
107
/// Shared runtime environment: either a fully loaded model, or nothing.
#[derive(Default)]
pub enum Environment {
    /// A model is loaded and ready to serve requests.
    Loaded {
        info: RuntimeInfo,
        runtime: Arc<dyn Runtime<Rnn> + Send + Sync>,
        model: Arc<dyn ModelSerialize + Send + Sync>,
        /// Channel into the inference loop (`crate::run::run`) accepting new
        /// generation contexts.
        sender: Sender<GenerateContext>,
    },
    /// No model is currently loaded.
    #[default]
    None,
}
119
/// Snapshot of the currently loaded runtime, handed out to API consumers.
#[derive(Derivative, Clone)]
#[derivative(Debug)]
pub struct RuntimeInfo {
    /// The reload request the runtime was built from.
    pub reload: Arc<ReloadRequest>,
    /// Model metadata read from the model file.
    pub info: ModelInfo,
    /// All initial states loaded alongside the model.
    pub states: Vec<InitState>,
    pub tokenizer: Arc<Tokenizer>,
}
128
/// Newtype wrapper so a concrete model type can be stored and serialized
/// behind the [`ModelSerialize`] trait object.
struct Model<M>(M);
130
/// Object-safe serialization of a loaded model into a file.
pub trait ModelSerialize {
    /// Write the model to `file`; errors on I/O or encoding failure.
    fn serialize(&self, file: std::fs::File) -> Result<()>;
}
134
135impl<M: Serialize> ModelSerialize for Model<M> {
136    fn serialize(&self, file: std::fs::File) -> Result<()> {
137        use cbor4ii::{core::enc::Write, serde::Serializer};
138        use std::{fs::File, io::Write as _};
139
140        struct FileWriter(File);
141        impl Write for FileWriter {
142            type Error = std::io::Error;
143            fn push(&mut self, input: &[u8]) -> Result<(), Self::Error> {
144                self.0.write_all(input)
145            }
146        }
147
148        let file = FileWriter(file);
149        let mut serializer = Serializer::new(file);
150        self.0.serialize(&mut serializer)?;
151
152        Ok(())
153    }
154}
155
/// Human-readable descriptions of the available GPU adapters.
#[derive(Debug, Default, Clone)]
pub struct AdapterList(pub Vec<String>);
158
/// What kind of output a generation request should produce.
#[derive(Debug, Default, Clone)]
pub enum GenerateKind {
    /// Normal text completion.
    #[default]
    None,
    /// The state of input.
    State,
    /// Choose options by perplexity.
    Choose {
        /// Candidate continuations to rank.
        choices: Vec<String>,
        // NOTE(review): exact calibration semantics live in `crate::run` —
        // confirm there before documenting further.
        calibrate: bool,
    },
}
172
/// Parameters of a single generation request.
#[derive(Clone, Derivative)]
#[derivative(Debug, Default)]
pub struct GenerateRequest {
    /// The prompt for the model.
    pub prompt: String,
    /// All text the model output earlier.
    pub model_text: String,
    /// Output token limit.
    pub max_tokens: usize,
    /// Stop indicators.
    pub stop: Vec<String>,
    /// Bias added to tokens before sampling.
    pub bias: Arc<HashMap<u32, f32>>,
    /// Optional BNF schema for formatted generation.
    pub bnf_schema: Option<String>,
    /// Sampler parameters.
    #[derivative(
        Debug = "ignore",
        Default(value = "Arc::new(RwLock::new(sampler::nucleus::NucleusSampler::default()))")
    )]
    pub sampler: Arc<RwLock<dyn Sampler + Send + Sync>>,
    /// Generation output kind.
    pub kind: GenerateKind,
    /// Initial state.
    pub state: Arc<InputState>,
}
199
/// Configuration used to (re)load a model into the runtime.
#[derive(Debug, Derivative, Clone, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
#[serde(default)]
pub struct ReloadRequest {
    /// Path to the model.
    #[salvo(schema(value_type = String))]
    pub model_path: PathBuf,
    /// List of LoRA blended on the model.
    pub lora: Vec<reload::Lora>,
    /// Path to the initial state.
    pub state: Vec<reload::State>,
    /// Specify layers that needs to be quantized.
    pub quant: usize,
    /// Quantization type (`Int8` or `NF4`).
    #[salvo(schema(value_type = sealed::Quant))]
    pub quant_type: Quant,
    /// Precision for intermediate tensors (`Fp16` or `Fp32`).
    pub precision: Precision,
    /// Maximum tokens to be processed in parallel at once.
    #[derivative(Default(value = "128"))]
    pub token_chunk_size: usize,
    /// Number of states that are cached on GPU.
    #[derivative(Default(value = "8"))]
    pub max_batch: usize,
    /// Path to the tokenizer.
    #[salvo(schema(value_type = String))]
    pub tokenizer_path: PathBuf,
    /// BNF options.
    pub bnf: BnfOption,
    /// Adapter selection.
    pub adapter: AdapterOption,
}
232
/// Request to serialize the currently loaded model to disk as a prefab.
#[derive(Debug, Default, Clone, Serialize, Deserialize, ToSchema)]
#[serde(default)]
pub struct SaveRequest {
    /// Path to save the model.
    #[serde(alias = "model_path")]
    #[salvo(schema(value_type = String))]
    pub path: PathBuf,
}
241
/// Minimal view into a prefab (CBOR-serialized) model file — only enough to
/// read the model info without deserializing the whole model.
#[derive(Debug, Deserialize)]
struct Prefab {
    info: ModelInfo,
}
246
/// On-disk model format, detected at load time.
#[derive(Debug, Clone, Copy)]
enum LoadType {
    /// Raw safetensors checkpoint.
    SafeTensors,
    /// CBOR-serialized prefab model (see [`Prefab`]).
    Prefab,
}
252
/// Unique identifier of a cached state (a UUID, serialized transparently).
#[derive(
    Derivative, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, ToSchema,
)]
#[derivative(Debug = "transparent")]
#[serde(transparent)]
pub struct StateId(uuid::Uuid);
259
260impl StateId {
261    pub fn new() -> Self {
262        Self(uuid::Uuid::new_v4())
263    }
264}
265
/// A state supplied inline by the user as raw tensor data.
#[derive(Debug, Default, Clone, Serialize, Deserialize, ToSchema)]
pub struct StateValue {
    /// Display name of the state.
    pub name: String,
    /// Identifier under which the state is cached.
    pub id: StateId,
    /// Flattened state data.
    pub data: Vec<f32>,
    /// Shape of `data` as a 4-D tensor.
    pub shape: [usize; 4],
}
273
/// A state referenced by a file path on the server.
#[derive(Debug, Default, Clone, Serialize, Deserialize, ToSchema)]
pub struct StateFile {
    /// Display name of the state.
    pub name: String,
    /// Identifier under which the state is cached.
    pub id: StateId,
    #[salvo(schema(value_type = String))]
    pub path: PathBuf,
}
281
/// State input from the user. Can be a single ID or full state data.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
#[serde(untagged)]
pub enum InputState {
    /// Reference to an already-cached state by its id.
    Key(StateId),
    /// Full state payload provided inline.
    Value(StateValue),
    /// State to be loaded from a file on the server.
    File(StateFile),
}
290
291impl Default for InputState {
292    fn default() -> Self {
293        Self::Key(Default::default())
294    }
295}
296
297impl InputState {
298    pub fn id(&self) -> StateId {
299        match self {
300            InputState::Key(id) => *id,
301            InputState::Value(value) => value.id,
302            InputState::File(file) => file.id,
303        }
304    }
305}
306
/// A named, loaded initial state that can be applied before generation.
#[derive(Derivative, Clone, Serialize, Deserialize)]
#[derivative(Debug)]
pub struct InitState {
    /// Display name of the state.
    pub name: String,
    /// Identifier under which the state is cached.
    pub id: StateId,
    /// Whether this state is applied by default.
    pub default: bool,
    /// The state tensor itself (excluded from `Debug` output).
    #[derivative(Debug = "ignore")]
    pub data: TensorCpu<f32>,
}
316
317impl TryFrom<StateValue> for InitState {
318    type Error = TensorError;
319
320    fn try_from(
321        StateValue {
322            name,
323            id,
324            data,
325            shape,
326        }: StateValue,
327    ) -> Result<Self, Self::Error> {
328        let default = false;
329        let data = TensorCpu::from_data(shape, data)?;
330        Ok(Self {
331            name,
332            id,
333            default,
334            data,
335        })
336    }
337}
338
339fn list_adapters() -> AdapterList {
340    let backends = Backends::all();
341    let instance = web_rwkv::wgpu::Instance::default();
342    let list = instance
343        .enumerate_adapters(backends)
344        .into_iter()
345        .map(|adapter| adapter.get_info())
346        .map(|info| format!("{} ({:?})", info.name, info.backend))
347        .collect();
348    AdapterList(list)
349}
350
351async fn create_context(adapter: AdapterOption, info: &ModelInfo) -> Result<Context> {
352    let backends = Backends::all();
353    let instance = web_rwkv::wgpu::Instance::default();
354    let adapter = match adapter {
355        AdapterOption::Auto => instance.adapter(PowerPreference::HighPerformance).await,
356        AdapterOption::Economical => instance.adapter(PowerPreference::LowPower).await,
357        AdapterOption::Manual(selection) => Ok(instance
358            .enumerate_adapters(backends)
359            .into_iter()
360            .nth(selection)
361            .ok_or(ContextError::RequestAdapterFailed)?),
362    }?;
363    let context = ContextBuilder::new(adapter)
364        .auto_limits(info)
365        .build()
366        .await?;
367    Ok(context)
368}
369
370async fn load_tokenizer(path: impl AsRef<Path>) -> Result<Tokenizer> {
371    let file = File::open(path).await?;
372    let mut reader = BufReader::new(file);
373    let mut contents = String::new();
374    reader.read_to_string(&mut contents).await?;
375    Ok(Tokenizer::new(&contents)?)
376}
377
/// Read an initial state tensor from a safetensors `model` source, dispatching
/// on the model version.
///
/// # Errors
/// V4 models do not support init states; reading may also fail if the file
/// contents do not match `info`.
async fn load_model_state<R: Reader>(
    context: &Context,
    info: &ModelInfo,
    model: R,
) -> Result<TensorCpu<f32>> {
    match info.version {
        ModelVersion::V4 => bail!("v4 does not support init state yet"),
        ModelVersion::V5 => Ok(v5::read_state(context, info, model).await?),
        ModelVersion::V6 => Ok(v6::read_state(context, info, model).await?),
        ModelVersion::V7 => Ok(v7::read_state(context, info, model).await?),
    }
}
390
/// Load the model at `request.model_path` and build the inference runtime.
///
/// Also loads all configured initial states and, for safetensors models, any
/// state embedded in the model file itself (registered as the default).
///
/// Returns the loaded init states, the runtime, its state handle, and a
/// serializable handle to the model (used by [`ThreadRequest::Save`]).
async fn load_runtime(
    context: &Context,
    info: &ModelInfo,
    request: &ReloadRequest,
    load: LoadType,
) -> Result<(
    Vec<InitState>,
    Arc<dyn Runtime<Rnn> + Send + Sync>,
    Arc<dyn State + Send + Sync>,
    Arc<dyn ModelSerialize + Send + Sync>,
)> {
    let ReloadRequest {
        model_path,
        lora,
        state,
        quant,
        quant_type,
        precision,
        max_batch,
        ..
    } = request.clone();

    // Load user-configured initial states. Entries that fail to load are
    // skipped with a warning instead of aborting the whole reload.
    let mut states = Vec::with_capacity(state.len());
    for state in state.into_iter() {
        let reload::State {
            path,
            name,
            id,
            default,
        } = state;
        // Fall back to the file name when no explicit name is given; skip
        // entries whose path has no file name component at all.
        let name = match name {
            Some(name) => name,
            None => match path.file_name() {
                Some(name) => name.to_string_lossy().to_string(),
                None => continue,
            },
        };
        let file = File::open(path).await?;
        // SAFETY-ish note: mmap of a file that may be modified concurrently is
        // inherently unsafe; the file is assumed to stay unchanged while mapped.
        let data = unsafe { Mmap::map(&file) }?;
        let model = SafeTensors::deserialize(&data)?;
        match load_model_state(context, info, model).await {
            Ok(data) => {
                let state = InitState {
                    name,
                    id,
                    data,
                    default,
                };
                log::info!("{:#?}", state);
                states.push(state);
            }
            Err(err) => log::warn!("initial state not loaded: {}", err),
        }
    }

    let file = File::open(model_path).await?;
    let data = unsafe { Mmap::map(&file) }?;

    match load {
        LoadType::SafeTensors => {
            // A safetensors model may carry its own embedded init state; if
            // present it is registered as the default state.
            let model = SafeTensors::deserialize(&data)?;
            if let Ok(data) = load_model_state(context, info, model).await {
                let name = "internal".into();
                let id = StateId::new();
                let state = InitState {
                    name,
                    id,
                    data,
                    default: true,
                };
                states.push(state);
            }

            // Re-deserialize: the previous view was consumed by the state read.
            let model = SafeTensors::deserialize(&data)?;
            // Quantize the first `quant` layers with the requested type.
            let quant = (0..quant).map(|layer| (layer, quant_type)).collect();
            // Memory-map all LoRA files concurrently...
            let lora: Vec<Result<_>> = join_all(lora.iter().map(|lora| async move {
                let reload::Lora { path, alpha } = lora;
                let file = File::open(path).await?;
                let data = unsafe { Mmap::map(&file)? };
                let blend = LoraBlend::full(*alpha);
                Ok((data, blend))
            }))
            .await;
            let lora: Vec<_> = lora.into_iter().try_collect()?;
            // ...then deserialize each mapping into a `Lora` to blend in.
            let lora: Vec<_> = lora
                .iter()
                .map(|(data, blend)| -> Result<_> {
                    let data = SafeTensors::deserialize(data)?;
                    let blend = blend.clone();
                    Ok(Lora { data, blend })
                })
                .try_collect()?;

            let builder = ModelBuilder::new(context, model).quant(quant);
            let builder = lora.into_iter().fold(builder, |builder, x| builder.lora(x));

            // Dispatch over (model version, precision) to the matching concrete
            // model/bundle types; every arm builds the runtime the same way.
            macro_rules! match_safe_tensors {
                (($v:expr, $p:expr), { $(($version:path, $precision:path, $model:ty, $build:ident, $bundle:ty)),+ }) => {
                    match ($v, $p) {
                        $(
                            ($version, $precision) => {
                                let model = builder.$build().await?;
                                let bundle = <$bundle>::new(model, max_batch);
                                let state = Arc::new(bundle.state());
                                let model = Arc::new(Model(bundle.model()));
                                let runtime = Arc::new(TokioRuntime::<Rnn>::new(bundle).await);
                                Ok((states, runtime, state, model))
                            }
                        )+
                    }
                }
            }
            match_safe_tensors!(
                (info.version, precision),
                {
                    (ModelVersion::V4, Precision::Fp16, v4::Model, build_v4, v4::Bundle::<f16>),
                    (ModelVersion::V5, Precision::Fp16, v5::Model, build_v5, v5::Bundle::<f16>),
                    (ModelVersion::V6, Precision::Fp16, v6::Model, build_v6, v6::Bundle::<f16>),
                    (ModelVersion::V7, Precision::Fp16, v7::Model, build_v7, v7::Bundle::<f16>),
                    (ModelVersion::V4, Precision::Fp32, v4::Model, build_v4, v4::Bundle::<f32>),
                    (ModelVersion::V5, Precision::Fp32, v5::Model, build_v5, v5::Bundle::<f32>),
                    (ModelVersion::V6, Precision::Fp32, v6::Model, build_v6, v6::Bundle::<f32>),
                    (ModelVersion::V7, Precision::Fp32, v7::Model, build_v7, v7::Bundle::<f32>)
                }
            )
        }
        LoadType::Prefab => {
            use cbor4ii::{core::utils::SliceReader, serde::Deserializer};

            let reader = SliceReader::new(&data);
            let mut deserializer = Deserializer::new(reader);

            // Same (version, precision) dispatch, but the model is deserialized
            // directly from the CBOR prefab instead of built from safetensors.
            macro_rules! match_prefab {
                (($v:expr, $p:expr), { $(($version:path, $precision:path, $model:ty, $bundle:ty)),+ }) => {
                    match ($v, $p) {
                        $(
                            ($version, $precision) => {
                                let seed: Seed<_, $model> = Seed::new(context);
                                let model = seed.deserialize(&mut deserializer)?;
                                let bundle = <$bundle>::new(model, max_batch);
                                let state = Arc::new(bundle.state());
                                let model = Arc::new(Model(bundle.model()));
                                let runtime = Arc::new(TokioRuntime::<Rnn>::new(bundle).await);
                                Ok((states, runtime, state, model))
                            }
                        )+
                    }
                }
            }
            match_prefab!(
                (info.version, precision),
                {
                    (ModelVersion::V4, Precision::Fp16, v4::Model, v4::Bundle::<f16>),
                    (ModelVersion::V5, Precision::Fp16, v5::Model, v5::Bundle::<f16>),
                    (ModelVersion::V6, Precision::Fp16, v6::Model, v6::Bundle::<f16>),
                    (ModelVersion::V7, Precision::Fp16, v7::Model, v7::Bundle::<f16>),
                    (ModelVersion::V4, Precision::Fp32, v4::Model, v4::Bundle::<f32>),
                    (ModelVersion::V5, Precision::Fp32, v5::Model, v5::Bundle::<f32>),
                    (ModelVersion::V6, Precision::Fp32, v6::Model, v6::Bundle::<f32>),
                    (ModelVersion::V7, Precision::Fp32, v7::Model, v7::Bundle::<f32>)
                }
            )
        }
    }
}
556
557async fn process(env: Arc<RwLock<Environment>>, request: ThreadRequest) -> Result<()> {
558    match request {
559        ThreadRequest::Adapter(sender) => {
560            let _ = sender.send(list_adapters());
561        }
562        ThreadRequest::Info(sender) => {
563            let env = env.read().await;
564            if let Environment::Loaded { info, .. } = &*env {
565                let _ = sender.send(info.clone());
566            }
567        }
568        ThreadRequest::Generate {
569            request,
570            tokenizer,
571            sender,
572        } => {
573            let context = GenerateContext::new(*request, sender, &tokenizer).await?;
574            let env = env.read().await;
575            if let Environment::Loaded { sender, .. } = &*env {
576                let _ = sender.send(context);
577            }
578        }
579        ThreadRequest::Reload { request, sender } => {
580            let handle = tokio::spawn(async move {
581                let file = File::open(&request.model_path).await?;
582                let data = unsafe { Mmap::map(&file)? };
583                let (info, load) = {
584                    let st = SafeTensors::deserialize(&data);
585                    let prefab = cbor4ii::serde::from_slice::<Prefab>(&data);
586                    match (st, prefab) {
587                        (Ok(model), _) => (Loader::info(&model)?, LoadType::SafeTensors),
588                        (_, Ok(prefab)) => (prefab.info, LoadType::Prefab),
589                        _ => bail!("failed to read model info"),
590                    }
591                };
592                log::info!("{:#?}", request);
593                log::info!("{:#?}", info);
594                log::info!("model type: {:?}", load);
595
596                let context = create_context(request.adapter, &info).await?;
597                log::info!("{:#?}", context.adapter.get_info());
598
599                let mut env = env.write().await;
600                let _ = std::mem::take(&mut *env);
601
602                let tokenizer = Arc::new(load_tokenizer(&request.tokenizer_path).await?);
603
604                let (states, runtime, state, model) =
605                    load_runtime(&context, &info, &request, load).await?;
606
607                let reload = Arc::new(*request);
608                let info = RuntimeInfo {
609                    reload,
610                    info,
611                    states,
612                    tokenizer,
613                };
614
615                let sender = {
616                    let runtime = Arc::downgrade(&runtime);
617                    let (sender, receiver) = flume::unbounded();
618                    tokio::spawn(crate::run::run(
619                        context,
620                        runtime,
621                        state,
622                        receiver,
623                        info.clone(),
624                    ));
625                    sender
626                };
627
628                log::info!("model loaded");
629
630                let _ = std::mem::replace(
631                    &mut *env,
632                    Environment::Loaded {
633                        info,
634                        runtime,
635                        model,
636                        sender,
637                    },
638                );
639                Ok(())
640            });
641
642            if let Some(sender) = sender {
643                let _ = match handle.await? {
644                    Ok(_) => sender.send(true),
645                    Err(err) => {
646                        log::error!("[reload] error: {err:#?}");
647                        sender.send(false)
648                    }
649                };
650            }
651        }
652        ThreadRequest::Unload => {
653            let mut env = env.write().await;
654            let _ = std::mem::take(&mut *env);
655            log::info!("model unloaded");
656        }
657        ThreadRequest::Save { request, sender } => {
658            let env = env.read().await;
659            if let Environment::Loaded { model, .. } = &*env {
660                log::info!("serializing model into {:?}", &request.path);
661                let model = model.clone();
662                let handle = tokio::task::spawn_blocking(move || {
663                    let file = std::fs::File::create(request.path)?;
664                    model.serialize(file)
665                });
666                drop(env);
667
668                let _ = match handle.await? {
669                    Ok(_) => sender.send(true),
670                    Err(err) => {
671                        log::error!("[save] error: {err:#?}");
672                        sender.send(false)
673                    }
674                };
675            }
676        }
677    };
678    Ok(())
679}
680
681pub async fn serve(receiver: Receiver<ThreadRequest>) {
682    let env: Arc<RwLock<Environment>> = Default::default();
683    while let Ok(request) = receiver.recv_async().await {
684        let future = process(env.clone(), request);
685        tokio::spawn(future);
686    }
687}
688
/// Schema-only mirror types: these exist solely so `salvo` can generate
/// OpenAPI schemas for external types that do not implement `ToSchema`
/// (referenced via `#[salvo(schema(value_type = sealed::...))]`, e.g. on
/// `ReloadRequest::quant_type`). They are never constructed at runtime.
#[allow(dead_code)]
mod sealed {
    use salvo::oapi::ToSchema;

    /// Mirror of the quantization type enum for schema generation.
    #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, ToSchema)]
    pub enum Quant {
        /// No quantization.
        #[default]
        None,
        /// Use `Int8` quantization.
        Int8,
        /// Use `NF4` quantization.
        NF4,
        /// Use `SF4` quantization.
        SF4,
    }

    /// Mirror of the embedding-device enum for schema generation.
    #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ToSchema)]
    pub enum EmbedDevice {
        #[default]
        Cpu,
        Gpu,
    }
}
712}