1use std::{
2 collections::HashMap,
3 path::{Path, PathBuf},
4 sync::Arc,
5};
6
7use anyhow::{bail, Result};
8use derivative::Derivative;
9use flume::{Receiver, Sender};
10use futures::future::join_all;
11use half::f16;
12use itertools::Itertools;
13use memmap2::Mmap;
14use reload::{AdapterOption, BnfOption, Precision};
15use safetensors::SafeTensors;
16use salvo::oapi::ToSchema;
17use serde::{de::DeserializeSeed, Deserialize, Serialize};
18use tokio::{
19 fs::File,
20 io::{AsyncReadExt, BufReader},
21 sync::RwLock,
22 time::Duration,
23};
24use web_rwkv::{
25 context::{Context, ContextBuilder, ContextError, InstanceExt},
26 runtime::{
27 infer::Rnn,
28 loader::{Loader, Lora, LoraBlend, Reader},
29 model::{Bundle, ContextAutoLimits, ModelBuilder, ModelInfo, ModelVersion, Quant, State},
30 v4, v5, v6, v7, Runtime, TokioRuntime,
31 },
32 tensor::{serialization::Seed, TensorCpu, TensorError, TensorInit},
33 tokenizer::Tokenizer,
34 wgpu::{Backends, PowerPreference},
35};
36
37use crate::{run::GenerateContext, sampler::Sampler};
38
39pub mod reload;
40pub mod run;
41pub mod sampler;
42
/// Largest representable token count (`usize::MAX`), i.e. effectively "no limit".
///
/// NOTE(review): presumably used as the cap/default for `GenerateRequest::max_tokens`;
/// the use sites are not visible in this file — confirm.
pub const MAX_TOKENS: usize = usize::MAX;
44
/// Events delivered through the `Sender<Token>` carried by [`ThreadRequest::Generate`].
///
/// NOTE(review): the events are produced outside this file (presumably in the `run` module);
/// the per-variant notes below are inferred from names and payload types — confirm.
#[derive(Debug)]
pub enum Token {
    /// Presumably marks the beginning of a response.
    Start,
    /// A piece of text.
    Content(String),
    /// Carries the reason the generation ended together with the token usage counters.
    Stop(FinishReason, TokenCounter),
    /// A flat `f32` buffer with a 4-element shape (presumably a state/embedding tensor).
    Embed(Vec<f32>, [usize; 4]),
    /// A list of `f32` values (presumably one score per choice of `GenerateKind::Choose`).
    Choose(Vec<f32>),
    /// Presumably marks the end of the event stream.
    Done,
}
54
/// Token usage statistics attached to [`Token::Stop`].
///
/// Each counter also accepts an OpenAI-style alias (`prompt_tokens`, `completion_tokens`,
/// `total_tokens`) when deserialized.
#[derive(Debug, Default, Clone, Serialize, Deserialize, ToSchema)]
pub struct TokenCounter {
    /// Number of prompt tokens.
    #[serde(alias = "prompt_tokens")]
    pub prompt: usize,
    /// Number of completion tokens.
    #[serde(alias = "completion_tokens")]
    pub completion: usize,
    /// Total number of tokens (presumably `prompt + completion`; computed elsewhere).
    #[serde(alias = "total_tokens")]
    pub total: usize,
    /// A time span (presumably the time spent on the request; measured elsewhere).
    pub duration: Duration,
}
65
/// Reason attached to [`Token::Stop`].
///
/// Variant names are serialized in `snake_case` (`stop`, `length`, `content_filter`).
#[derive(Debug, Default, Clone, Copy, Serialize, ToSchema)]
#[serde(rename_all = "snake_case")]
#[allow(dead_code)]
pub enum FinishReason {
    /// Presumably: a stop sequence or natural stop point was reached.
    Stop,
    /// Presumably: the token limit was reached.
    Length,
    /// Presumably: content was omitted by a filter.
    ContentFilter,
    /// The default value. Marked `#[serde(untagged)]`, so it is serialized without the
    /// variant name (presumably as `null` in JSON — confirm).
    #[default]
    #[serde(untagged)]
    Null,
}
81
/// Requests accepted by [`serve`]; each one is handled by `process`.
#[derive(Debug, Clone)]
pub enum ThreadRequest {
    /// Ask for the list of available GPU adapters; the list is sent back through the sender.
    Adapter(Sender<AdapterList>),
    /// Ask for the info of the currently loaded runtime. Nothing is sent if no model is loaded.
    Info(Sender<RuntimeInfo>),
    /// Start a generation. The request is dropped if no model is loaded.
    Generate {
        /// The generation parameters.
        request: Box<GenerateRequest>,
        /// Tokenizer handed to `GenerateContext::new` together with the request.
        tokenizer: Arc<Tokenizer>,
        /// Channel on which the resulting [`Token`] events are delivered.
        sender: Sender<Token>,
    },
    /// Unload the current model (if any) and load a new one.
    Reload {
        /// Describes the model, tokenizer and options to load.
        request: Box<ReloadRequest>,
        /// Optional channel that receives `true` on success and `false` on failure.
        sender: Option<Sender<bool>>,
    },
    /// Drop the currently loaded model, resetting the environment to `Environment::None`.
    Unload,
    /// Serialize the currently loaded model into a file.
    Save {
        /// Where to write the serialized model.
        request: SaveRequest,
        /// Receives `true` on success and `false` on failure.
        sender: Sender<bool>,
    },
}
107
/// Shared server state: either a fully loaded model/runtime or nothing.
#[derive(Default)]
pub enum Environment {
    /// A model is loaded and a `run` task is consuming generation contexts.
    Loaded {
        /// Info describing the loaded model (reload request, model info, states, tokenizer).
        info: RuntimeInfo,
        /// Strong reference to the inference runtime; the `run` task only receives a weak one.
        runtime: Arc<dyn Runtime<Rnn> + Send + Sync>,
        /// Handle used to serialize the loaded model on `ThreadRequest::Save`.
        model: Arc<dyn ModelSerialize + Send + Sync>,
        /// Channel feeding generation contexts to the `run` task spawned on reload.
        sender: Sender<GenerateContext>,
    },
    /// No model is loaded (the default).
    #[default]
    None,
}
119
/// Description of a loaded runtime, as reported by `ThreadRequest::Info` and passed to the
/// `run` task.
#[derive(Derivative, Clone)]
#[derivative(Debug)]
pub struct RuntimeInfo {
    /// The request the model was loaded with.
    pub reload: Arc<ReloadRequest>,
    /// Model info read from the model file.
    pub info: ModelInfo,
    /// Initial states loaded alongside the model.
    pub states: Vec<InitState>,
    /// Tokenizer loaded from `ReloadRequest::tokenizer_path`.
    pub tokenizer: Arc<Tokenizer>,
}
128
/// Newtype around a model value, so that [`ModelSerialize`] can be implemented for any
/// `M: Serialize` and stored as a trait object.
struct Model<M>(M);
130
/// Object-safe abstraction over "a model that can be written to a file".
pub trait ModelSerialize {
    /// Serializes the model into `file`, consuming the file handle.
    fn serialize(&self, file: std::fs::File) -> Result<()>;
}
134
135impl<M: Serialize> ModelSerialize for Model<M> {
136 fn serialize(&self, file: std::fs::File) -> Result<()> {
137 use cbor4ii::{core::enc::Write, serde::Serializer};
138 use std::{fs::File, io::Write as _};
139
140 struct FileWriter(File);
141 impl Write for FileWriter {
142 type Error = std::io::Error;
143 fn push(&mut self, input: &[u8]) -> Result<(), Self::Error> {
144 self.0.write_all(input)
145 }
146 }
147
148 let file = FileWriter(file);
149 let mut serializer = Serializer::new(file);
150 self.0.serialize(&mut serializer)?;
151
152 Ok(())
153 }
154}
155
/// Human-readable adapter descriptions, one per adapter, formatted as `name (backend)`.
#[derive(Debug, Default, Clone)]
pub struct AdapterList(pub Vec<String>);
158
/// What kind of output a [`GenerateRequest`] asks for.
///
/// NOTE(review): the variants are interpreted outside this file; the descriptions below are
/// inferred from names — confirm against the `run` module.
#[derive(Debug, Default, Clone)]
pub enum GenerateKind {
    /// The default kind (presumably ordinary text generation).
    #[default]
    None,
    /// Presumably: return the model state instead of text.
    State,
    /// Presumably: score the given choices instead of generating free text.
    Choose {
        /// The candidate texts.
        choices: Vec<String>,
        /// Flag whose effect is not visible here (presumably enables score calibration).
        calibrate: bool,
    },
}
172
/// Parameters of a single generation, consumed by `GenerateContext::new`.
///
/// NOTE(review): the fields are interpreted outside this file; descriptions marked
/// "presumably" are inferred from names and types — confirm against the `run` module.
#[derive(Clone, Derivative)]
#[derivative(Debug, Default)]
pub struct GenerateRequest {
    /// The prompt text.
    pub prompt: String,
    /// Presumably: text already produced by the model that the generation continues from.
    pub model_text: String,
    /// Presumably: the maximum number of tokens to generate.
    pub max_tokens: usize,
    /// Presumably: sequences that stop the generation when they appear.
    pub stop: Vec<String>,
    /// Presumably: per-token logit bias, keyed by token id.
    pub bias: Arc<HashMap<u32, f32>>,
    /// Optional BNF schema (presumably constrains the output grammar).
    pub bnf_schema: Option<String>,
    /// The sampler to use. Skipped in `Debug` output; defaults to a nucleus sampler.
    #[derivative(
        Debug = "ignore",
        Default(value = "Arc::new(RwLock::new(sampler::nucleus::NucleusSampler::default()))")
    )]
    pub sampler: Arc<RwLock<dyn Sampler + Send + Sync>>,
    /// The kind of output requested.
    pub kind: GenerateKind,
    /// The initial state to use, given by id, by value or by file.
    pub state: Arc<InputState>,
}
199
/// Everything needed to load a model; the payload of `ThreadRequest::Reload`.
///
/// Because of `#[serde(default)]`, fields missing from the input take their default value.
#[derive(Debug, Derivative, Clone, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
#[serde(default)]
pub struct ReloadRequest {
    /// Path to the model file (SafeTensors or CBOR prefab).
    #[salvo(schema(value_type = String))]
    pub model_path: PathBuf,
    /// LoRA files to blend into the model (only applied when loading from SafeTensors).
    pub lora: Vec<reload::Lora>,
    /// Initial state files to load alongside the model.
    pub state: Vec<reload::State>,
    /// Number of layers to quantize: layers `0..quant` are given `quant_type`.
    pub quant: usize,
    /// Quantization type applied to the quantized layers.
    #[salvo(schema(value_type = sealed::Quant))]
    pub quant_type: Quant,
    /// Numeric precision (`Fp16` or `Fp32`) selecting the bundle type that is built.
    pub precision: Precision,
    /// Defaults to 128. NOTE(review): not read in this file; presumably the maximum number
    /// of tokens processed per inference step — confirm.
    #[derivative(Default(value = "128"))]
    pub token_chunk_size: usize,
    /// Batch size passed to the bundle constructor. Defaults to 8.
    #[derivative(Default(value = "8"))]
    pub max_batch: usize,
    /// Path to the tokenizer vocabulary file.
    #[salvo(schema(value_type = String))]
    pub tokenizer_path: PathBuf,
    /// BNF options (not read in this file).
    pub bnf: BnfOption,
    /// Which GPU adapter to create the context on.
    pub adapter: AdapterOption,
}
232
/// Payload of `ThreadRequest::Save`.
#[derive(Debug, Default, Clone, Serialize, Deserialize, ToSchema)]
#[serde(default)]
pub struct SaveRequest {
    /// Destination file for the serialized model; also accepted under the name `model_path`.
    #[serde(alias = "model_path")]
    #[salvo(schema(value_type = String))]
    pub path: PathBuf,
}
241
/// Minimal view of a CBOR prefab file, used only to probe the model info when deciding how
/// to load a model file.
#[derive(Debug, Deserialize)]
struct Prefab {
    /// The model info stored in the prefab.
    info: ModelInfo,
}
246
/// On-disk format detected for a model file.
#[derive(Debug, Clone, Copy)]
enum LoadType {
    /// The file parsed as SafeTensors.
    SafeTensors,
    /// The file parsed as a CBOR prefab.
    Prefab,
}
252
/// Identifier of an initial state: a UUID wrapper that serializes and debug-prints
/// transparently as the inner UUID.
///
/// The derived `Default` delegates to the inner UUID's `Default`; use [`StateId::new`] for a
/// fresh random id.
#[derive(
    Derivative, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, ToSchema,
)]
#[derivative(Debug = "transparent")]
#[serde(transparent)]
pub struct StateId(uuid::Uuid);
259
260impl StateId {
261 pub fn new() -> Self {
262 Self(uuid::Uuid::new_v4())
263 }
264}
265
/// An initial state given inline, by value.
#[derive(Debug, Default, Clone, Serialize, Deserialize, ToSchema)]
pub struct StateValue {
    /// Display name of the state.
    pub name: String,
    /// Identifier of the state.
    pub id: StateId,
    /// Flat tensor data; combined with `shape` when converted into an [`InitState`].
    pub data: Vec<f32>,
    /// Shape of the tensor stored in `data`.
    pub shape: [usize; 4],
}
273
/// An initial state given as a reference to a file.
#[derive(Debug, Default, Clone, Serialize, Deserialize, ToSchema)]
pub struct StateFile {
    /// Display name of the state.
    pub name: String,
    /// Identifier of the state.
    pub id: StateId,
    /// Path of the state file.
    #[salvo(schema(value_type = String))]
    pub path: PathBuf,
}
281
/// The initial state selected by a generation request.
///
/// Serialized untagged: serde tries the variants in declaration order when deserializing,
/// so the order of the variants below matters.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
#[serde(untagged)]
pub enum InputState {
    /// Refers to a state by id only.
    Key(StateId),
    /// Carries the state data inline.
    Value(StateValue),
    /// Refers to a state stored in a file.
    File(StateFile),
}
290
291impl Default for InputState {
292 fn default() -> Self {
293 Self::Key(Default::default())
294 }
295}
296
297impl InputState {
298 pub fn id(&self) -> StateId {
299 match self {
300 InputState::Key(id) => *id,
301 InputState::Value(value) => value.id,
302 InputState::File(file) => file.id,
303 }
304 }
305}
306
/// An initial state held as a CPU tensor.
#[derive(Derivative, Clone, Serialize, Deserialize)]
#[derivative(Debug)]
pub struct InitState {
    /// Display name of the state.
    pub name: String,
    /// Identifier of the state.
    pub id: StateId,
    /// `true` for states flagged as default in the reload request and for the state read
    /// from the model file itself; `false` for states converted from a [`StateValue`].
    pub default: bool,
    /// The state tensor. Omitted from `Debug` output.
    #[derivative(Debug = "ignore")]
    pub data: TensorCpu<f32>,
}
316
317impl TryFrom<StateValue> for InitState {
318 type Error = TensorError;
319
320 fn try_from(
321 StateValue {
322 name,
323 id,
324 data,
325 shape,
326 }: StateValue,
327 ) -> Result<Self, Self::Error> {
328 let default = false;
329 let data = TensorCpu::from_data(shape, data)?;
330 Ok(Self {
331 name,
332 id,
333 default,
334 data,
335 })
336 }
337}
338
339fn list_adapters() -> AdapterList {
340 let backends = Backends::all();
341 let instance = web_rwkv::wgpu::Instance::default();
342 let list = instance
343 .enumerate_adapters(backends)
344 .into_iter()
345 .map(|adapter| adapter.get_info())
346 .map(|info| format!("{} ({:?})", info.name, info.backend))
347 .collect();
348 AdapterList(list)
349}
350
351async fn create_context(adapter: AdapterOption, info: &ModelInfo) -> Result<Context> {
352 let backends = Backends::all();
353 let instance = web_rwkv::wgpu::Instance::default();
354 let adapter = match adapter {
355 AdapterOption::Auto => instance.adapter(PowerPreference::HighPerformance).await,
356 AdapterOption::Economical => instance.adapter(PowerPreference::LowPower).await,
357 AdapterOption::Manual(selection) => Ok(instance
358 .enumerate_adapters(backends)
359 .into_iter()
360 .nth(selection)
361 .ok_or(ContextError::RequestAdapterFailed)?),
362 }?;
363 let context = ContextBuilder::new(adapter)
364 .auto_limits(info)
365 .build()
366 .await?;
367 Ok(context)
368}
369
370async fn load_tokenizer(path: impl AsRef<Path>) -> Result<Tokenizer> {
371 let file = File::open(path).await?;
372 let mut reader = BufReader::new(file);
373 let mut contents = String::new();
374 reader.read_to_string(&mut contents).await?;
375 Ok(Tokenizer::new(&contents)?)
376}
377
378async fn load_model_state<R: Reader>(
379 context: &Context,
380 info: &ModelInfo,
381 model: R,
382) -> Result<TensorCpu<f32>> {
383 match info.version {
384 ModelVersion::V4 => bail!("v4 does not support init state yet"),
385 ModelVersion::V5 => Ok(v5::read_state(context, info, model).await?),
386 ModelVersion::V6 => Ok(v6::read_state(context, info, model).await?),
387 ModelVersion::V7 => Ok(v7::read_state(context, info, model).await?),
388 }
389}
390
/// Loads the initial states and the model described by `request`, then builds the runtime.
///
/// Steps, in order:
/// 1. Every entry of `request.state` is memory-mapped, parsed as SafeTensors and read with
///    [`load_model_state`]. Failing to open, map or parse a state file aborts the whole load,
///    whereas a failure inside `load_model_state` is only logged as a warning.
/// 2. The model file is memory-mapped and, depending on `load`, either built from
///    SafeTensors (quantization and LoRA applied) or deserialized from a CBOR prefab
///    (quantization and LoRA settings are not used on that path).
///
/// Returns `(states, runtime, state, model)`: the loaded initial states, the inference
/// runtime, the bundle's state handle and a serializable handle to the model.
async fn load_runtime(
    context: &Context,
    info: &ModelInfo,
    request: &ReloadRequest,
    load: LoadType,
) -> Result<(
    Vec<InitState>,
    Arc<dyn Runtime<Rnn> + Send + Sync>,
    Arc<dyn State + Send + Sync>,
    Arc<dyn ModelSerialize + Send + Sync>,
)> {
    let ReloadRequest {
        model_path,
        lora,
        state,
        quant,
        quant_type,
        precision,
        max_batch,
        ..
    } = request.clone();

    let mut states = Vec::with_capacity(state.len());
    for state in state.into_iter() {
        let reload::State {
            path,
            name,
            id,
            default,
        } = state;
        // Fall back to the file name when no explicit name is given; an entry that has
        // neither is skipped silently.
        let name = match name {
            Some(name) => name,
            None => match path.file_name() {
                Some(name) => name.to_string_lossy().to_string(),
                None => continue,
            },
        };
        let file = File::open(path).await?;
        // NOTE(review): mapping is only sound if the file is not modified or truncated by
        // anyone else while the map is alive — nothing here enforces that.
        let data = unsafe { Mmap::map(&file) }?;
        let model = SafeTensors::deserialize(&data)?;
        match load_model_state(context, info, model).await {
            Ok(data) => {
                let state = InitState {
                    name,
                    id,
                    data,
                    default,
                };
                // The tensor data is excluded from `InitState`'s `Debug` output.
                log::info!("{:#?}", state);
                states.push(state);
            }
            Err(err) => log::warn!("initial state not loaded: {}", err),
        }
    }

    let file = File::open(model_path).await?;
    // NOTE(review): same mmap caveat as above — the model file must stay untouched while
    // mapped.
    let data = unsafe { Mmap::map(&file) }?;

    match load {
        LoadType::SafeTensors => {
            // Try to read a state from the model file itself; if that works it is added as a
            // default state named "internal". A failure here is ignored.
            let model = SafeTensors::deserialize(&data)?;
            if let Ok(data) = load_model_state(context, info, model).await {
                let name = "internal".into();
                let id = StateId::new();
                let state = InitState {
                    name,
                    id,
                    data,
                    default: true,
                };
                states.push(state);
            }

            // The first parse was consumed by `load_model_state`; parse again for the builder.
            let model = SafeTensors::deserialize(&data)?;
            // Layers `0..quant` are assigned `quant_type`.
            let quant = (0..quant).map(|layer| (layer, quant_type)).collect();
            // Phase 1: open and map all LoRA files concurrently.
            let lora: Vec<Result<_>> = join_all(lora.iter().map(|lora| async move {
                let reload::Lora { path, alpha } = lora;
                let file = File::open(path).await?;
                // NOTE(review): same mmap caveat as for the model file.
                let data = unsafe { Mmap::map(&file)? };
                let blend = LoraBlend::full(*alpha);
                Ok((data, blend))
            }))
            .await;
            let lora: Vec<_> = lora.into_iter().try_collect()?;
            // Phase 2: parse each map. This iterates by reference, so the previous `lora`
            // binding (which owns the maps) stays alive behind the shadowing one.
            let lora: Vec<_> = lora
                .iter()
                .map(|(data, blend)| -> Result<_> {
                    let data = SafeTensors::deserialize(data)?;
                    let blend = blend.clone();
                    Ok(Lora { data, blend })
                })
                .try_collect()?;

            let builder = ModelBuilder::new(context, model).quant(quant);
            let builder = lora.into_iter().fold(builder, |builder, x| builder.lora(x));

            // Expands to one match arm per (version, precision) pair. Each arm builds the
            // model with `$build`, wraps it in `$bundle` and creates the runtime.
            // `$model` is accepted but not used by the expansion.
            macro_rules! match_safe_tensors {
                (($v:expr, $p:expr), { $(($version:path, $precision:path, $model:ty, $build:ident, $bundle:ty)),+ }) => {
                    match ($v, $p) {
                        $(
                            ($version, $precision) => {
                                let model = builder.$build().await?;
                                let bundle = <$bundle>::new(model, max_batch);
                                let state = Arc::new(bundle.state());
                                let model = Arc::new(Model(bundle.model()));
                                let runtime = Arc::new(TokioRuntime::<Rnn>::new(bundle).await);
                                Ok((states, runtime, state, model))
                            }
                        )+
                    }
                }
            }
            match_safe_tensors!(
                (info.version, precision),
                {
                    (ModelVersion::V4, Precision::Fp16, v4::Model, build_v4, v4::Bundle::<f16>),
                    (ModelVersion::V5, Precision::Fp16, v5::Model, build_v5, v5::Bundle::<f16>),
                    (ModelVersion::V6, Precision::Fp16, v6::Model, build_v6, v6::Bundle::<f16>),
                    (ModelVersion::V7, Precision::Fp16, v7::Model, build_v7, v7::Bundle::<f16>),
                    (ModelVersion::V4, Precision::Fp32, v4::Model, build_v4, v4::Bundle::<f32>),
                    (ModelVersion::V5, Precision::Fp32, v5::Model, build_v5, v5::Bundle::<f32>),
                    (ModelVersion::V6, Precision::Fp32, v6::Model, build_v6, v6::Bundle::<f32>),
                    (ModelVersion::V7, Precision::Fp32, v7::Model, build_v7, v7::Bundle::<f32>)
                }
            )
        }
        LoadType::Prefab => {
            use cbor4ii::{core::utils::SliceReader, serde::Deserializer};

            // Deserialize straight from the mapped bytes.
            let reader = SliceReader::new(&data);
            let mut deserializer = Deserializer::new(reader);

            // Expands to one match arm per (version, precision) pair. Each arm deserializes
            // `$model` through a context-aware seed, wraps it in `$bundle` and creates the
            // runtime.
            macro_rules! match_prefab {
                (($v:expr, $p:expr), { $(($version:path, $precision:path, $model:ty, $bundle:ty)),+ }) => {
                    match ($v, $p) {
                        $(
                            ($version, $precision) => {
                                let seed: Seed<_, $model> = Seed::new(context);
                                let model = seed.deserialize(&mut deserializer)?;
                                let bundle = <$bundle>::new(model, max_batch);
                                let state = Arc::new(bundle.state());
                                let model = Arc::new(Model(bundle.model()));
                                let runtime = Arc::new(TokioRuntime::<Rnn>::new(bundle).await);
                                Ok((states, runtime, state, model))
                            }
                        )+
                    }
                }
            }
            match_prefab!(
                (info.version, precision),
                {
                    (ModelVersion::V4, Precision::Fp16, v4::Model, v4::Bundle::<f16>),
                    (ModelVersion::V5, Precision::Fp16, v5::Model, v5::Bundle::<f16>),
                    (ModelVersion::V6, Precision::Fp16, v6::Model, v6::Bundle::<f16>),
                    (ModelVersion::V7, Precision::Fp16, v7::Model, v7::Bundle::<f16>),
                    (ModelVersion::V4, Precision::Fp32, v4::Model, v4::Bundle::<f32>),
                    (ModelVersion::V5, Precision::Fp32, v5::Model, v5::Bundle::<f32>),
                    (ModelVersion::V6, Precision::Fp32, v6::Model, v6::Bundle::<f32>),
                    (ModelVersion::V7, Precision::Fp32, v7::Model, v7::Bundle::<f32>)
                }
            )
        }
    }
}
556
557async fn process(env: Arc<RwLock<Environment>>, request: ThreadRequest) -> Result<()> {
558 match request {
559 ThreadRequest::Adapter(sender) => {
560 let _ = sender.send(list_adapters());
561 }
562 ThreadRequest::Info(sender) => {
563 let env = env.read().await;
564 if let Environment::Loaded { info, .. } = &*env {
565 let _ = sender.send(info.clone());
566 }
567 }
568 ThreadRequest::Generate {
569 request,
570 tokenizer,
571 sender,
572 } => {
573 let context = GenerateContext::new(*request, sender, &tokenizer).await?;
574 let env = env.read().await;
575 if let Environment::Loaded { sender, .. } = &*env {
576 let _ = sender.send(context);
577 }
578 }
579 ThreadRequest::Reload { request, sender } => {
580 let handle = tokio::spawn(async move {
581 let file = File::open(&request.model_path).await?;
582 let data = unsafe { Mmap::map(&file)? };
583 let (info, load) = {
584 let st = SafeTensors::deserialize(&data);
585 let prefab = cbor4ii::serde::from_slice::<Prefab>(&data);
586 match (st, prefab) {
587 (Ok(model), _) => (Loader::info(&model)?, LoadType::SafeTensors),
588 (_, Ok(prefab)) => (prefab.info, LoadType::Prefab),
589 _ => bail!("failed to read model info"),
590 }
591 };
592 log::info!("{:#?}", request);
593 log::info!("{:#?}", info);
594 log::info!("model type: {:?}", load);
595
596 let context = create_context(request.adapter, &info).await?;
597 log::info!("{:#?}", context.adapter.get_info());
598
599 let mut env = env.write().await;
600 let _ = std::mem::take(&mut *env);
601
602 let tokenizer = Arc::new(load_tokenizer(&request.tokenizer_path).await?);
603
604 let (states, runtime, state, model) =
605 load_runtime(&context, &info, &request, load).await?;
606
607 let reload = Arc::new(*request);
608 let info = RuntimeInfo {
609 reload,
610 info,
611 states,
612 tokenizer,
613 };
614
615 let sender = {
616 let runtime = Arc::downgrade(&runtime);
617 let (sender, receiver) = flume::unbounded();
618 tokio::spawn(crate::run::run(
619 context,
620 runtime,
621 state,
622 receiver,
623 info.clone(),
624 ));
625 sender
626 };
627
628 log::info!("model loaded");
629
630 let _ = std::mem::replace(
631 &mut *env,
632 Environment::Loaded {
633 info,
634 runtime,
635 model,
636 sender,
637 },
638 );
639 Ok(())
640 });
641
642 if let Some(sender) = sender {
643 let _ = match handle.await? {
644 Ok(_) => sender.send(true),
645 Err(err) => {
646 log::error!("[reload] error: {err:#?}");
647 sender.send(false)
648 }
649 };
650 }
651 }
652 ThreadRequest::Unload => {
653 let mut env = env.write().await;
654 let _ = std::mem::take(&mut *env);
655 log::info!("model unloaded");
656 }
657 ThreadRequest::Save { request, sender } => {
658 let env = env.read().await;
659 if let Environment::Loaded { model, .. } = &*env {
660 log::info!("serializing model into {:?}", &request.path);
661 let model = model.clone();
662 let handle = tokio::task::spawn_blocking(move || {
663 let file = std::fs::File::create(request.path)?;
664 model.serialize(file)
665 });
666 drop(env);
667
668 let _ = match handle.await? {
669 Ok(_) => sender.send(true),
670 Err(err) => {
671 log::error!("[save] error: {err:#?}");
672 sender.send(false)
673 }
674 };
675 }
676 }
677 };
678 Ok(())
679}
680
681pub async fn serve(receiver: Receiver<ThreadRequest>) {
682 let env: Arc<RwLock<Environment>> = Default::default();
683 while let Ok(request) = receiver.recv_async().await {
684 let future = process(env.clone(), request);
685 tokio::spawn(future);
686 }
687}
688
/// Schema-only stand-ins for types that are described in the OpenAPI schema through
/// `#[salvo(schema(value_type = ...))]` (see `ReloadRequest::quant_type`).
#[allow(dead_code)]
mod sealed {
    use salvo::oapi::ToSchema;

    /// Schema stand-in for the quantization type of `ReloadRequest::quant_type`.
    #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, ToSchema)]
    pub enum Quant {
        /// No quantization (the default).
        #[default]
        None,
        /// Presumably 8-bit integer quantization.
        Int8,
        /// Presumably 4-bit NF4 quantization.
        NF4,
        /// Presumably 4-bit SF4 quantization.
        SF4,
    }

    /// Schema stand-in for an embed device selection; not referenced in this file.
    #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ToSchema)]
    pub enum EmbedDevice {
        /// The default.
        #[default]
        Cpu,
        Gpu,
    }
}