1use std::path::Path;
4
5use crate::model::TlModel;
6use crate::tensor::TlTensor;
7use crate::train::predict_linfa;
8
9pub fn predict(model: &TlModel, input: &TlTensor) -> Result<TlTensor, String> {
11 match model {
12 TlModel::Onnx { path, .. } => predict_onnx(path, input),
13 TlModel::Linfa { .. } => predict_linfa(model, input),
14 TlModel::LlmEndpoint { .. } => {
15 Err("Cannot use predict() on an LLM endpoint. Use ai_complete() instead.".to_string())
16 }
17 }
18}
19
20pub fn predict_onnx(model_path: &Path, input: &TlTensor) -> Result<TlTensor, String> {
22 use ort::session::Session;
23
24 let mut session = Session::builder()
25 .and_then(|mut b| b.commit_from_file(model_path))
26 .map_err(|e| format!("Failed to load ONNX model: {e}"))?;
27
28 let shape = input.shape();
29 let flat_data: Vec<f32> = input.to_vec().iter().map(|&x| x as f32).collect();
30 let shape_i64: Vec<i64> = shape.iter().map(|&s| s as i64).collect();
31
32 let input_value = ort::value::Tensor::from_array((shape_i64, flat_data))
34 .map_err(|e| format!("Failed to create ORT tensor: {e}"))?;
35
36 let outputs = session
37 .run(ort::inputs![input_value])
38 .map_err(|e| format!("ONNX inference failed: {e}"))?;
39
40 let output = outputs.values().next().ok_or("No output from ONNX model")?;
42
43 let (out_shape_ref, out_flat) = output
44 .try_extract_tensor::<f32>()
45 .map_err(|e| format!("Failed to extract output: {e}"))?;
46
47 let out_shape: Vec<usize> = out_shape_ref.iter().map(|&d| d as usize).collect();
48 let out_data: Vec<f64> = out_flat.iter().map(|&x| x as f64).collect();
49
50 TlTensor::from_vec(out_data, &out_shape)
51}
52
53pub fn predict_batch(
55 model: &TlModel,
56 input: &TlTensor,
57 batch_size: usize,
58) -> Result<TlTensor, String> {
59 let shape = input.shape();
60 if shape.len() < 2 {
61 return predict(model, input);
62 }
63
64 let n_samples = shape[0];
65 if n_samples <= batch_size {
66 return predict(model, input);
67 }
68
69 let mut all_preds = Vec::new();
70
71 for start in (0..n_samples).step_by(batch_size) {
72 let end = (start + batch_size).min(n_samples);
73 let batch = input.slice(start, end)?;
74 let preds = predict(model, &batch)?;
75 all_preds.extend(preds.to_vec());
76 }
77
78 Ok(TlTensor::from_list(all_preds))
79}