use std::{
    fmt::{Debug, Display},
    path::PathBuf,
    sync::{Arc, Mutex},
};

use anyhow::{anyhow, bail, Result};
use async_trait::async_trait;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use tokio::net::TcpListener;

use crate::{
    models::common::{detect_text_model_arch, Cache, Config},
    utils, Args, ModelType, TextModelArch,
};
16
// Master-only surface: the public HTTP API and the master orchestration
// module are compiled in only when the `master` feature is enabled.
#[cfg(feature = "master")]
pub mod api;
#[cfg(feature = "master")]
mod master;

pub mod auth;
mod client;
pub mod discovery;
mod proto;
pub mod setup;
mod topology;
mod worker;

// Re-export the private modules' public items so callers use this module's
// namespace directly instead of the internal module paths.
#[cfg(feature = "master")]
pub use master::*;

pub use client::*;
pub use proto::*;
pub use topology::*;
pub use worker::*;
37
38#[derive(clap::ValueEnum, Clone, Debug, Default)]
40pub enum Mode {
41 #[default]
42 Master,
43 Worker,
44}
45
/// Shared runtime state built once at startup (see [`Context::from_args`])
/// and threaded through both master and worker code paths.
#[derive(Clone)]
pub struct Context {
    /// Parsed command-line arguments.
    pub args: Args,
    /// Tensor dtype used for inference (F16 unless overridden via `--dtype`).
    pub dtype: DType,
    /// Layer-to-worker assignment; empty when no topology file was given.
    pub topology: Topology,
    /// Local directory containing the model files (possibly downloaded).
    pub data_path: PathBuf,
    /// Compute device all tensors are placed on.
    pub device: Device,
    /// Parsed model config; `None` for non-text model types.
    pub config: Option<Config>,
    /// KV cache; `None` for non-text model types.
    pub cache: Option<Cache>,
    /// Weight loader over the safetensor shards this node needs;
    /// `None` for non-text model types.
    pub var_builder: Option<VarBuilder<'static>>,
    /// Resolved text architecture (auto-detected when args say `Auto`).
    pub text_model_arch: TextModelArch,
    /// Whether the checkpoint is FP8-quantized (dequantized at load time).
    pub fp8: bool,
    /// Externally-supplied TCP listener that can override the default bind.
    /// NOTE(review): presumably set by tests/embedders before serving starts
    /// and taken by whoever binds — confirm against callers.
    pub listener_override: Arc<Mutex<Option<TcpListener>>>,
}
64
impl Context {
    /// Builds the shared runtime [`Context`] from parsed CLI arguments.
    ///
    /// Resolves dtype and device, locates (or downloads) the model data,
    /// loads the topology, and — for text models only — detects the
    /// architecture, parses the config, opens the weight shards this node
    /// needs, and allocates the KV cache.
    ///
    /// # Errors
    /// Fails on an unsupported `--dtype`, an unattachable device, a missing
    /// model path that is not a HF repo id, or any config/weight load error.
    pub fn from_args(mut args: Args) -> Result<Self> {
        // Dtype defaults to F16 when the flag is absent; any other string is
        // rejected up front rather than deep inside model loading.
        let dtype: DType = match args.dtype.as_deref() {
            Some("f16") => DType::F16,
            Some("bf16") => DType::BF16,
            Some("f32") => DType::F32,
            Some(dtype) => bail!("unsupported dtype {dtype}"),
            None => DType::F16,
        };

        let device = utils::get_inference_device(args.cpu, args.device)
            .map_err(|e| anyhow!("can't attach to device: {:?}", e))?;

        #[cfg(feature = "cuda")]
        if let Device::Cuda(cuda_dev) = &device {
            // SAFETY: NOTE(review): the invariant required by
            // `disable_event_tracking` is not visible from here — confirm it
            // is sound to call once, right after device creation, before any
            // kernels are launched.
            unsafe { cuda_dev.disable_event_tracking(); }
        }

        log::info!(
            "[{:?}] dtype={:?} device={:?} mem={}",
            args.mode,
            &dtype,
            &device,
            // Current physical memory of this process, or 0 if unavailable.
            human_bytes::human_bytes(memory_stats::memory_stats().map(|m| m.physical_mem).unwrap_or(0) as f64)
        );

        // `--model` is either a local path or a HF repo id; download only
        // when the path does not exist locally and looks like a repo id.
        let data_path = PathBuf::from(&args.model);
        let data_path = if !data_path.exists() {
            if utils::hf::looks_like_hf_repo(&args.model) {
                utils::hf::ensure_model_downloaded(&args.model)?
            } else {
                bail!("model path does not exist: {}", data_path.display());
            }
        } else {
            data_path
        };

        // Topology precedence: programmatic override (moved out of `args`
        // with `take`) > topology file > empty topology (full local load).
        let topology = if let Some(topo) = args.topology_override.take() {
            topo
        } else if let Some(path) = &args.topology {
            Topology::from_path(path, &args.model_type)?
        } else {
            log::warn!("no topology file specified, the entire model will be loaded");
            Topology::new()
        };

        // Populated only for text models in the block below.
        let mut config: Option<Config> = None;
        let mut cache: Option<Cache> = None;
        let mut var_builder: Option<VarBuilder> = None;
        let mut text_model_arch = args.text_model_arch;
        let mut fp8 = false;

        if args.model_type == ModelType::TextModel {
            let config_filename = data_path.join("config.json");

            // Auto-detect the architecture from config.json; unknown or
            // undetectable architectures fall back to Llama.
            if text_model_arch == TextModelArch::Auto {
                let arch_str = detect_text_model_arch(&config_filename).unwrap_or_default();
                text_model_arch = match arch_str.as_str() {
                    #[cfg(feature = "qwen2")]
                    "Qwen2ForCausalLM" => TextModelArch::Qwen2,
                    #[cfg(feature = "qwen3_5")]
                    "Qwen3_5ForConditionalGeneration" => TextModelArch::Qwen3_5,
                    _ => TextModelArch::Llama,
                };
            }

            log::info!("text model architecture: {:?}", text_model_arch);

            // Parse the architecture-specific config and normalize it into
            // the common `Config`. Each arm is feature-gated, so a resolved
            // arch whose feature is off falls through to the error arm.
            let config_internal = match text_model_arch {
                #[cfg(feature = "qwen2")]
                TextModelArch::Qwen2 => {
                    crate::models::qwen2::QwenConfig::from_path(&config_filename)?.into_config()
                }
                #[cfg(feature = "qwen3_5")]
                TextModelArch::Qwen3_5 => {
                    crate::models::qwen3_5::Qwen3_5Config::from_path(&config_filename)?.into_config()
                }
                #[cfg(feature = "llama")]
                TextModelArch::Llama => {
                    crate::models::llama3::LlamaConfig::from_path(&config_filename)?.into_config()
                }
                _ => {
                    bail!("no text model feature enabled for architecture {:?}", text_model_arch)
                }
            };

            let model_tensors_index: PathBuf = data_path.join("model.safetensors.index.json");
            fp8 = utils::fp8::is_fp8_quantized(&config_filename);
            if fp8 {
                log::info!("model uses FP8 quantization — weights will be dequantized at load time");
            }
            let is_master = matches!(args.mode, Mode::Master);
            // NOTE(review): a worker collects *every* worker layer from the
            // topology, not only its own assignment — confirm whether
            // per-worker filtering is intended to happen elsewhere.
            let my_layers: Vec<String> = if !is_master {
                topology.all_worker_layers().into_iter().collect()
            } else {
                vec![]
            };

            var_builder = Some(if is_master {
                let worker_layers = topology.all_worker_layers();
                if worker_layers.is_empty() {
                    // No workers configured: master loads the full model.
                    utils::load_var_builder_from_index(
                        model_tensors_index,
                        dtype,
                        device.clone(),
                        fp8,
                    )?
                } else {
                    // Master loads the local share, presumably excluding the
                    // layers that workers will serve (helper not visible here).
                    utils::load_var_builder_for_local_layers(
                        model_tensors_index,
                        dtype,
                        device.clone(),
                        &worker_layers,
                        fp8,
                    )?
                }
            } else if !my_layers.is_empty() {
                // Worker with assigned layers: load only those.
                utils::load_var_builder_for_specific_layers(
                    model_tensors_index,
                    dtype,
                    device.clone(),
                    &my_layers,
                    fp8,
                )?
            } else {
                // Worker with no topology entries: load everything.
                utils::load_var_builder_from_index(
                    model_tensors_index,
                    dtype,
                    device.clone(),
                    fp8,
                )?
            });
            // Cache is sized from the parsed config; config is stored after
            // the cache borrows it.
            cache = Some(Cache::new(true, dtype, &config_internal, &device)?);
            config = Some(config_internal);
        }

        Ok(Context {
            args,
            dtype,
            topology,
            data_path,
            device,
            config,
            cache,
            var_builder,
            text_model_arch,
            fp8,
            listener_override: Arc::new(Mutex::new(None)),
        })
    }
}
227
/// A unit of the model that can execute a forward pass, whether it runs in
/// this process or proxies to a remote worker.
#[async_trait]
pub trait Forwarder: Debug + Send + Sync + Display {
    /// Instantiates the forwarder for the layer called `name`, pulling
    /// weights, device, and config from `ctx`.
    fn load(name: String, ctx: &Context) -> Result<Box<Self>>
    where
        Self: Sized;

    /// Runs a forward pass through `&self`.
    ///
    /// NOTE(review): semantics inferred from names — `index_pos` looks like
    /// the position within the sequence/KV cache and `block_idx` the
    /// transformer-block index; confirm against implementations.
    async fn forward(
        &self,
        x: &Tensor,
        index_pos: usize,
        block_idx: usize,
        ctx: &mut Context,
    ) -> Result<Tensor>;

    /// Same as [`Forwarder::forward`] but takes `&mut self`, for
    /// implementations that mutate internal state while forwarding.
    async fn forward_mut(
        &mut self,
        x: &Tensor,
        index_pos: usize,
        block_idx: usize,
        ctx: &mut Context,
    ) -> Result<Tensor>;

    /// Batched forward pass. The default panics via `unimplemented!`, so
    /// only implementations that support batching should override it.
    ///
    /// NOTE(review): the tuple fields look like
    /// `(request/session id, index_pos, block_idx)` — confirm with callers.
    async fn forward_batch(
        &mut self,
        _x: &Tensor,
        _batch: Vec<(String, usize, usize)>,
        _ctx: &mut Context,
    ) -> Result<Tensor> {
        unimplemented!()
    }

    /// Shutdown/teardown hook; the default is a no-op.
    async fn goodbye(&mut self) -> Result<()> {
        Ok(())
    }

    /// Name of the layer (range) this forwarder serves.
    fn layer_name(&self) -> &str;

    /// Identifies where execution happens; defaults to `"local"` for
    /// in-process implementations.
    fn ident(&self) -> &str {
        "local"
    }
}