1use std::fs;
2use std::io::{BufWriter, Write};
3use std::net::SocketAddr;
4use std::path::{Path, PathBuf};
5
6use anyhow::{Context, Result};
7
8use crate::codegen::CodeGenerator;
9use crate::model::{DataType, Model, TensorData};
10
/// File name of the concatenated tensor blob, written into the generated
/// project root and embedded by the generated `main.rs` via `include_bytes!`.
const EMBEDDED_WEIGHTS_FILE: &str = "embedded_weights.bin";

/// Code generator that emits a self-contained native Rust HTTP server crate
/// with the model weights compiled into the binary.
#[derive(Debug, Clone, Copy, Default)]
pub struct NativeCodegen;
14
15impl CodeGenerator for NativeCodegen {
16 fn generate(&self, model: &Model, output_dir: &Path, listen: SocketAddr) -> Result<PathBuf> {
17 let project_dir = output_dir.join("modelc_build");
18 let src_dir = project_dir.join("src");
19
20 fs::create_dir_all(&src_dir).with_context(|| "failed to create build directory")?;
21
22 let mut names: Vec<&String> = model.tensors.keys().collect();
23 names.sort();
24
25 let blob_path = project_dir.join(EMBEDDED_WEIGHTS_FILE);
26 let blob_file = fs::File::create(&blob_path)
27 .with_context(|| "failed to create embedded weight blob")?;
28 let mut writer = BufWriter::new(blob_file);
29
30 let mut tensor_loads = String::new();
31 let mut blob_offset = 0usize;
32
33 for name in &names {
34 let tensor = model.tensors.get(*name).expect("tensor key mismatch");
35 let offset = blob_offset;
36 let byte_len = tensor.data.len();
37 writer
38 .write_all(&tensor.data)
39 .with_context(|| format!("writing tensor {name} into blob"))?;
40
41 blob_offset = blob_offset.saturating_add(byte_len);
42
43 let shape_fmt = format!("{:?}", tensor.shape);
44 let dtype_size = tensor.dtype.byte_size();
45 tensor_loads.push_str(&format!(
46 " ({:?}, TensorMeta {{ shape: &{shape_fmt}, dtype_size: {dtype_size}, byte_offset: {offset}, byte_len: {byte_len} }}),\n",
47 name
48 ));
49 }
50
51 writer
52 .flush()
53 .with_context(|| "flushing embedded weight blob")?;
54
55 let cargo_toml = generate_cargo_toml();
56 let listen_str = listen.to_string();
57 let optional_helpers = emit_mlp_helpers(model);
58 let forward_fn = emit_forward_fn(model);
59 let main_rs = generate_main_rs(
60 model,
61 EMBEDDED_WEIGHTS_FILE,
62 &tensor_loads,
63 &listen_str,
64 optional_helpers.trim_end(),
65 forward_fn.trim_end(),
66 );
67
68 fs::write(project_dir.join("Cargo.toml"), cargo_toml)?;
69 fs::write(src_dir.join("main.rs"), main_rs)?;
70
71 Ok(project_dir)
72 }
73}
74
/// Returns the `Cargo.toml` for the generated `model-serve` crate.
///
/// The trailing empty `[workspace]` table makes the generated crate its own
/// workspace root. Without it, `cargo build` fails when the output directory
/// happens to sit inside another Cargo workspace.
fn generate_cargo_toml() -> String {
    r#"[package]
name = "model-serve"
version = "0.1.0"
edition = "2021"

[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"

[profile.release]
opt-level = 3
lto = true
strip = true

[workspace]
"#
    .to_string()
}
94
/// Escapes `s` so it can be placed between double quotes in generated Rust
/// source and evaluate back to exactly `s`.
///
/// Escaping only `\` and `"` is not enough: rustc rejects a bare carriage
/// return inside a string literal. `str::escape_debug` also escapes `\r`,
/// `\n`, `\t`, `\0`, `'`, and non-printable characters (as `\u{..}`), and
/// every one of those escapes is valid inside a Rust string literal.
fn escape_rust_string_literal(s: &str) -> String {
    s.escape_debug().to_string()
}
98
99fn generate_main_rs(
100 model: &Model,
101 weights_file: &str,
102 tensor_loads: &str,
103 listen: &str,
104 optional_helpers: &str,
105 forward_fn: &str,
106) -> String {
107 let model_name_esc = escape_rust_string_literal(&model.name);
108 let arch_esc = escape_rust_string_literal(&model.architecture);
109 let listen_esc = escape_rust_string_literal(listen);
110
111 let total_params = model.total_params();
112 let total_bytes = model.total_bytes();
113
114 format!(
115 r##"use std::collections::HashMap;
116use std::sync::Arc;
117
118use axum::{{Json, Router, extract::State, routing::{{get, post}}}};
119use serde::{{Deserialize, Serialize}};
120
121struct TensorMeta {{
122 shape: &'static [usize],
123 dtype_size: usize,
124 byte_offset: usize,
125 byte_len: usize,
126}}
127
128struct AppState {{
129 weights: &'static [u8],
130 tensors: HashMap<&'static str, TensorMeta>,
131}}
132
133#[derive(Deserialize)]
134struct InferRequest {{
135 input: Vec<f32>,
136}}
137
138#[derive(Serialize)]
139struct InferResponse {{
140 output: Vec<f32>,
141}}
142
143#[derive(Serialize)]
144struct ModelInfo {{
145 name: &'static str,
146 architecture: &'static str,
147 total_params: usize,
148 total_bytes: usize,
149 tensors: Vec<String>,
150}}
151
152const MODEL_NAME: &str = "{model_name_esc}";
153const MODEL_ARCHITECTURE: &str = "{arch_esc}";
154
155#[tokio::main]
156async fn main() {{
157 let weights: &'static [u8] = include_bytes!("../{weights_file}");
158
159 let mut tensors = HashMap::new();
160 let tensor_defs: Vec<(&str, TensorMeta)> = vec![
161{tensor_loads} ];
162 for (name, meta) in tensor_defs {{
163 tensors.insert(name, meta);
164 }}
165
166 let state = Arc::new(AppState {{ weights, tensors }});
167
168 let app = Router::new()
169 .route("/infer", post(infer))
170 .route("/info", get(model_info))
171 .with_state(state);
172
173 let addr = "{listen_esc}"
174 .parse::<std::net::SocketAddr>()
175 .expect("embedded listen address");
176
177 let total_mb = {total_bytes} as f64 / (1024.0 * 1024.0);
178 eprintln!(
179 "model-serve: listening on http://{{}}\n model: {{}}\n architecture: {{}}\n parameters: {total_params}\n weight blob: {total_bytes} bytes (~{{:.4}} MB)",
180 addr,
181 MODEL_NAME,
182 MODEL_ARCHITECTURE,
183 total_mb,
184 );
185
186 let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
187 axum::serve(listener, app).await.unwrap();
188}}
189
190async fn infer(
191 State(state): State<Arc<AppState>>,
192 Json(req): Json<InferRequest>,
193) -> Json<InferResponse> {{
194 let result = forward(&state, &req.input);
195 Json(InferResponse {{ output: result }})
196}}
197
198async fn model_info(State(state): State<Arc<AppState>>) -> Json<ModelInfo> {{
199 Json(ModelInfo {{
200 name: MODEL_NAME,
201 architecture: MODEL_ARCHITECTURE,
202 total_params: {total_params},
203 total_bytes: {total_bytes},
204 tensors: state.tensors.keys().map(|k| k.to_string()).collect(),
205 }})
206}}
207
208{optional_helpers}
209
210{forward_fn}
211"##
212 )
213}
214
215fn emit_mlp_helpers(model: &Model) -> String {
216 if infer_mlp_plan(model).is_none() {
217 return String::new();
218 }
219
220 r#"fn decode_f32(state: &AppState, name: &'static str) -> Vec<f32> {
221 let meta = state.tensors.get(name).expect("tensor lookup");
222 let slice = &state.weights[meta.byte_offset..meta.byte_offset + meta.byte_len];
223 decode_f32_le(slice).expect("decoder expects fp32 payloads for this server")
224}
225
226fn decode_f32_le(bytes: &[u8]) -> Option<Vec<f32>> {
227 if bytes.len() % 4 != 0 {
228 return None;
229 }
230 let mut out = Vec::with_capacity(bytes.len() / 4);
231 for chunk in bytes.chunks_exact(4) {
232 out.push(f32::from_le_bytes(chunk.try_into().ok()?));
233 }
234 Some(out)
235}
236
237fn matmul_bias(
238 state: &AppState,
239 weight_key: &'static str,
240 bias_key: &'static str,
241 x: &[f32],
242) -> Vec<f32> {
243 let w_meta = state
244 .tensors
245 .get(weight_key)
246 .unwrap_or_else(|| panic!("missing tensor `{}`", weight_key));
247 assert_eq!(
248 w_meta.dtype_size, 4,
249 "`{}` must be fp32 for this emitted server",
250 weight_key
251 );
252 assert_eq!(
253 w_meta.shape.len(),
254 2,
255 "`{}`: expected dense matrix",
256 weight_key,
257 );
258
259 let rows = w_meta.shape[0];
260 let cols = w_meta.shape[1];
261 assert_eq!(
262 cols,
263 x.len(),
264 "`{}` gemv mismatch: cols {cols}, input {}",
265 weight_key,
266 x.len(),
267 );
268
269 let w_flat = decode_f32(state, weight_key);
270 assert_eq!(w_flat.len(), rows.checked_mul(cols).expect("sizes"));
271
272 let bias = decode_f32(state, bias_key);
273 assert_eq!(bias.len(), rows);
274
275 let mut out = vec![0f32; rows];
276 for r in 0..rows {
277 let mut acc = bias[r];
278 let row = &w_flat[r * cols..(r + 1) * cols];
279 for (wv, xv) in row.iter().zip(x.iter()) {
280 acc += *wv * *xv;
281 }
282 out[r] = acc;
283 }
284 out
285}
286
287fn relu_inplace(xs: &mut [f32]) {
288 for v in xs {
289 *v = v.max(0.0);
290 }
291}"#
292 .trim_end()
293 .to_string()
294}
295
296fn emit_forward_fn(model: &Model) -> String {
297 if let Some(plan) = infer_mlp_plan(model) {
298 build_forward_from_plan(plan)
299 } else {
300 placeholder_forward()
301 }
302}
303
/// Fallback `forward` used when no MLP plan is available: it returns the
/// request input unchanged.
fn placeholder_forward() -> String {
    let source = concat!(
        "fn forward(_state: &AppState, input: &[f32]) -> Vec<f32> {\n",
        "    input.to_vec()\n",
        "}",
    );
    String::from(source)
}
310
311fn build_forward_from_plan(plan: Vec<(String, String)>) -> String {
312 debug_assert!(!plan.is_empty());
313
314 if plan.len() == 1 {
315 let (w, b) = &plan[0];
316 return finalize_forward(&format!(
317 " matmul_bias(state, {:?}, {:?}, input)",
318 w.as_str(),
319 b.as_str()
320 ));
321 }
322
323 let (w0, b0) = &plan[0];
324 let mut body = format!(
325 " let mut cur = matmul_bias(state, {:?}, {:?}, input);\n relu_inplace(&mut cur);\n",
326 w0.as_str(),
327 b0.as_str(),
328 );
329
330 let last_global = plan.len() - 1;
331 for (global_idx, (w, b)) in plan.iter().enumerate().skip(1) {
332 body.push_str(&format!(
333 " cur = matmul_bias(state, {:?}, {:?}, &cur);\n",
334 w.as_str(),
335 b.as_str(),
336 ));
337
338 if global_idx != last_global {
339 body.push_str(" relu_inplace(&mut cur);\n");
340 }
341 }
342
343 body.push_str(" cur");
344 finalize_forward(&body)
345}
346
/// Wraps `body` (trailing whitespace removed) in the generated
/// `fn forward(state, input) -> Vec<f32>` signature.
fn finalize_forward(body: &str) -> String {
    const HEADER: &str = "fn forward(state: &AppState, input: &[f32]) -> Vec<f32> {\n";
    let trimmed = body.trim_end();

    let mut source = String::with_capacity(HEADER.len() + trimmed.len() + 2);
    source.push_str(HEADER);
    source.push_str(trimmed);
    source.push_str("\n}");
    source
}
353
354fn infer_mlp_plan(model: &Model) -> Option<Vec<(String, String)>> {
355 if model.architecture != "mlp" {
356 return None;
357 }
358
359 layered_mlp_pairs(model).or_else(|| singleton_affine_pair(model))
360}
361
362fn singleton_affine_pair(model: &Model) -> Option<Vec<(String, String)>> {
363 validate_affine_pair(model, "weight", "bias")?;
364 Some(vec![("weight".to_string(), "bias".to_string())])
365}
366
367fn layered_mlp_pairs(model: &Model) -> Option<Vec<(String, String)>> {
368 let mut ids: Vec<u32> = model
369 .tensors
370 .keys()
371 .filter_map(|key| parse_layer_suffix(key.as_str()))
372 .collect();
373 if ids.is_empty() {
374 return None;
375 }
376
377 ids.sort_unstable();
378 ids.dedup();
379
380 if !ids.windows(2).all(|pair| pair[1] == pair[0] + 1) {
381 return None;
382 }
383
384 let mut seq = Vec::new();
385 let mut prev_out_rows: Option<usize> = None;
386
387 for id in ids {
388 let weight_name = format!("layer{id}.weight");
389 let bias_name = format!("layer{id}.bias");
390 let (rows, cols) = affine_pair_shape(model, &weight_name, &bias_name)?;
391
392 if let Some(out_prev) = prev_out_rows
393 && out_prev != cols
394 {
395 return None;
396 }
397
398 seq.push((weight_name, bias_name));
399 prev_out_rows = Some(rows);
400 }
401
402 Some(seq)
403}
404
405fn affine_pair_shape(model: &Model, weight_name: &str, bias_name: &str) -> Option<(usize, usize)> {
406 validate_affine_pair(model, weight_name, bias_name)?;
407 let w = model.tensors.get(weight_name)?;
408 Some((*w.shape.first()?, *w.shape.get(1)?))
409}
410
411fn validate_affine_pair<'m>(
412 model: &'m Model,
413 weight_name: &str,
414 bias_name: &str,
415) -> Option<&'m TensorData> {
416 let w = model.tensors.get(weight_name)?;
417 let b = model.tensors.get(bias_name)?;
418
419 if w.dtype != DataType::F32 || b.dtype != DataType::F32 {
420 return None;
421 }
422
423 let rows = *w.shape.first()?;
424 if w.shape.len() != 2 || b.shape.len() != 1 {
425 return None;
426 }
427
428 (b.shape[0] == rows).then_some(w)
429}
430
/// Extracts `N` from a tensor named exactly `layer{N}.weight`.
///
/// Only the canonical decimal spelling is accepted, i.e. `name` must equal
/// `format!("layer{N}.weight")`. That is how `layered_mlp_pairs` rebuilds
/// names. `u32::from_str` alone also accepts `+1` and `01`, and such ids would
/// map to tensor names that do not exist, which breaks MLP detection for the
/// whole model.
fn parse_layer_suffix(name: &str) -> Option<u32> {
    let digits = name.strip_prefix("layer")?.strip_suffix(".weight")?;
    let id: u32 = digits.parse().ok()?;
    (id.to_string() == digits).then_some(id)
}