use alloc::{borrow::Cow, sync::Arc};
use core::{
fmt,
ptr::{self, NonNull}
};
use std::path::Path;
use ort_sys::c_char;
use super::{Checkpoint, Optimizer, training_api};
use crate::{
AsPointer,
environment::Environment,
error::{Error, Result},
memory::Allocator,
ortsys,
session::{RunOptions, SessionInputValue, SessionInputs, SessionOutputs, builder::SessionBuilder},
util::{char_p_to_string, with_cstr_ptr_array},
value::{IntoTensorElementType, Tensor, Value}
};
/// A handle to an `OrtTrainingSession`, together with the model input/output names that were read from the
/// session when the trainer was constructed.
///
/// The underlying session is released when the `Trainer` is dropped.
#[derive(Debug)]
pub struct Trainer {
/// Non-null pointer to the underlying training session.
ptr: NonNull<ort_sys::OrtTrainingSession>,
/// Output names of the training model; used to size and label the outputs of a training step.
train_output_names: Vec<String>,
/// Output names of the evaluation model; used to size and label the outputs of an evaluation step.
eval_output_names: Vec<String>,
/// Input names of the training model; used to order named inputs passed to `step`.
train_input_names: Vec<String>,
/// Input names of the evaluation model; used to order named inputs passed to `eval_step`.
eval_input_names: Vec<String>,
/// Checkpoint the session was created with; exposed through `checkpoint()`.
ckpt: Checkpoint,
/// Allocator supplied at construction. It is only used while reading the I/O names and is otherwise just
/// stored here (presumably to keep it alive alongside the session — TODO confirm).
_allocator: Allocator,
/// Shared handle to the environment the session was created in (presumably held so the environment outlives
/// the session — TODO confirm).
_environment: Arc<Environment>
}
impl Trainer {
/// Creates a trainer from a checkpoint plus training, evaluation and optimizer models stored on disk.
///
/// The environment is taken from `session_options`. After the session has been created, the input and
/// output names of the training and evaluation models are read using `allocator`.
///
/// # Errors
/// Returns an error if `CreateTrainingSession` fails, if the session pointer fails the `nonNull` check, or if
/// reading the model I/O names fails.
pub fn new(
    session_options: SessionBuilder,
    allocator: Allocator,
    ckpt: Checkpoint,
    training_model_path: impl AsRef<Path>,
    eval_model_path: impl AsRef<Path>,
    optimizer_model_path: impl AsRef<Path>
) -> Result<Self> {
    let environment = &session_options.environment;

    // Each converted path stays bound to a local for the duration of the call that receives its pointer.
    let training_path = crate::util::path_to_os_char(training_model_path);
    let eval_path = crate::util::path_to_os_char(eval_model_path);
    let optimizer_path = crate::util::path_to_os_char(optimizer_model_path);

    let mut ptr: *mut ort_sys::OrtTrainingSession = ptr::null_mut();
    ortsys![@training:
        unsafe CreateTrainingSession(
            environment.ptr(),
            session_options.ptr(),
            ckpt.ptr.as_ptr(),
            training_path.as_ptr(),
            eval_path.as_ptr(),
            optimizer_path.as_ptr(),
            &mut ptr
        )?;
        nonNull(ptr)
    ];

    Self::new_inner(ptr, environment, allocator, ckpt)
}
/// Creates a trainer from a directory of training artifacts.
///
/// The directory is expected to contain `training_model.onnx`, `eval_model.onnx` and `optimizer_model.onnx`.
/// The checkpoint is loaded from `<base_dir>/checkpoint` unless `override_ckpt` is supplied, in which case the
/// supplied checkpoint is used and nothing is loaded.
///
/// # Errors
/// Returns an error if the checkpoint cannot be loaded or if [`Trainer::new`] fails.
pub fn new_from_artifacts(
    session_options: SessionBuilder,
    allocator: Allocator,
    base_dir: impl AsRef<Path>,
    override_ckpt: Option<Checkpoint>
) -> Result<Self> {
    let artifacts_dir = base_dir.as_ref();

    let ckpt = match override_ckpt {
        Some(provided) => provided,
        None => Checkpoint::load(artifacts_dir.join("checkpoint"))?
    };

    let training_model = artifacts_dir.join("training_model.onnx");
    let eval_model = artifacts_dir.join("eval_model.onnx");
    let optimizer_model = artifacts_dir.join("optimizer_model.onnx");

    Self::new(session_options, allocator, ckpt, training_model, eval_model, optimizer_model)
}
/// Creates a trainer from a checkpoint plus training, evaluation and optimizer models held in memory.
///
/// Each model is passed to `CreateTrainingSessionFromBuffer` as a pointer/length pair taken from its byte
/// slice. The environment is taken from `session_options`.
///
/// # Errors
/// Returns an error if `CreateTrainingSessionFromBuffer` fails, if the session pointer fails the `nonNull`
/// check, or if reading the model I/O names fails.
pub fn new_from_memory(
    session_options: SessionBuilder,
    allocator: Allocator,
    ckpt: Checkpoint,
    training_model: &[u8],
    eval_model: &[u8],
    optimizer_model: &[u8]
) -> Result<Self> {
    let environment = &session_options.environment;

    let (training_data, training_len) = (training_model.as_ptr(), training_model.len());
    let (eval_data, eval_len) = (eval_model.as_ptr(), eval_model.len());
    let (optimizer_data, optimizer_len) = (optimizer_model.as_ptr(), optimizer_model.len());

    let mut ptr: *mut ort_sys::OrtTrainingSession = ptr::null_mut();
    ortsys![@training:
        unsafe CreateTrainingSessionFromBuffer(
            environment.ptr(),
            session_options.ptr(),
            ckpt.ptr.as_ptr(),
            training_data.cast(),
            training_len,
            eval_data.cast(),
            eval_len,
            optimizer_data.cast(),
            optimizer_len,
            &mut ptr
        )?;
        nonNull(ptr)
    ];

    Self::new_inner(ptr, environment, allocator, ckpt)
}
fn new_inner(ptr: NonNull<ort_sys::OrtTrainingSession>, environment: &Arc<Environment>, allocator: Allocator, ckpt: Checkpoint) -> Result<Self> {
let api = training_api()?;
let train_output_names =
extract_io_names(ptr, &allocator, api.TrainingSessionGetTrainingModelOutputCount, api.TrainingSessionGetTrainingModelOutputName)?;
let eval_output_names = extract_io_names(ptr, &allocator, api.TrainingSessionGetEvalModelOutputCount, api.TrainingSessionGetEvalModelOutputName)?;
let train_input_names = extract_io_names(ptr, &allocator, api.TrainingSessionGetTrainingModelInputCount, api.TrainingSessionGetTrainingModelInputName)?;
let eval_input_names = extract_io_names(ptr, &allocator, api.TrainingSessionGetEvalModelInputCount, api.TrainingSessionGetEvalModelInputName)?;
crate::logging::create!(Trainer, ptr);
Ok(Self {
ptr,
train_output_names,
train_input_names,
eval_output_names,
eval_input_names,
ckpt,
_allocator: allocator,
_environment: Arc::clone(environment)
})
}
/// Runs one training step, feeding `inputs` followed by `labels` to the training model.
///
/// Positional arguments (slices and arrays) are forwarded in the order given. Named arguments (maps) are
/// first arranged with `mapped_inputs` against the training model's input names. No run options are passed.
///
/// Returns the outputs of the step paired with the training model's output names.
///
/// NOTE(review): `mapped_inputs` returns one entry per training-model input name, with `None` for every name
/// that is missing from the map. Whenever a mapped list is chained with another list below, the combined
/// list can therefore be longer than the number of model inputs, and each `None` is forwarded by
/// `step_inner` as a null pointer. Confirm that this is the intended handling of named inputs/labels.
///
/// # Errors
/// Returns an error if the underlying `TrainStep` call fails.
pub fn step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>(
&'s self,
inputs: impl Into<SessionInputs<'i1, 'v1, N1>>,
labels: impl Into<SessionInputs<'i2, 'v2, N2>>
) -> Result<SessionOutputs<'s>> {
match inputs.into() {
// Positional inputs given as a slice.
SessionInputs::ValueSlice(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels).map(Some), None),
SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()).map(Some), None),
SessionInputs::ValueMap(labels) => {
// Named labels are looked up against *all* training input names, not only the trailing ones.
let labels = mapped_inputs(&self.train_input_names, &labels);
self.step_inner(input_values.iter().map(Some).chain(labels), None)
}
},
// Positional inputs given as a fixed-size array.
SessionInputs::ValueArray(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels).map(Some), None),
SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()).map(Some), None),
SessionInputs::ValueMap(labels) => {
let labels = mapped_inputs(&self.train_input_names, &labels);
self.step_inner(input_values.iter().map(Some).chain(labels), None)
}
},
// Named inputs: ordered by the training model's input names before any labels are appended.
SessionInputs::ValueMap(input_values) => {
let input_values = mapped_inputs(&self.train_input_names, &input_values);
match labels.into() {
SessionInputs::ValueSlice(labels) => self.step_inner(input_values.into_iter().chain(labels.iter().map(Some)), None),
SessionInputs::ValueArray(labels) => self.step_inner(input_values.into_iter().chain(labels.iter().map(Some)), None),
SessionInputs::ValueMap(labels) => {
let labels = mapped_inputs(&self.train_input_names, &labels);
self.step_inner(input_values.into_iter().chain(labels), None)
}
}
}
}
}
/// Shared implementation of [`Trainer::step`]: calls `TrainStep` with the given inputs and wraps the results.
///
/// Every `None` in `input_values` is passed to the native call as a null `OrtValue` pointer. One output slot
/// is reserved per training-model output name, and the returned values are paired with those names.
///
/// # Panics
/// Panics if the native call succeeds but leaves any output pointer null.
///
/// # Errors
/// Returns an error if `TrainStep` reports a failure.
fn step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>(
    &'s self,
    input_values: impl Iterator<Item = Option<&'i1 SessionInputValue<'v1>>>,
    run_options: Option<&'r RunOptions>
) -> Result<SessionOutputs<'r>> {
    let feeds: Vec<*const ort_sys::OrtValue> = input_values
        .map(|input| match input {
            Some(value) => value.ptr(),
            None => ptr::null()
        })
        .collect();
    let mut fetches: Vec<*mut ort_sys::OrtValue> = vec![ptr::null_mut(); self.train_output_names.len()];
    let options_ptr = run_options.map_or(ptr::null(), |options| options.ptr());

    ortsys![@training:
        unsafe TrainStep(
            self.ptr.as_ptr(),
            options_ptr,
            feeds.len(),
            feeds.as_ptr(),
            fetches.len(),
            fetches.as_mut_ptr()
        )?
    ];

    let values = fetches
        .into_iter()
        .map(|raw| {
            let value_ptr = NonNull::new(raw).expect("OrtValue ptr returned from session Run should not be null");
            unsafe { Value::from_ptr(value_ptr, None) }
        })
        .collect();
    let names = self.train_output_names.iter().map(String::as_str).collect();
    Ok(SessionOutputs::new(names, values))
}
/// Runs one evaluation step, feeding `inputs` followed by `labels` to the evaluation model.
///
/// Positional arguments (slices and arrays) are forwarded in the order given. Named arguments (maps) are
/// first arranged with `mapped_inputs` against the evaluation model's input names. No run options are passed.
///
/// Returns the outputs of the step paired with the evaluation model's output names.
///
/// NOTE(review): `mapped_inputs` returns one entry per evaluation-model input name, with `None` for every
/// name that is missing from the map. Whenever a mapped list is chained with another list below, the combined
/// list can therefore be longer than the number of model inputs, and each `None` is forwarded by
/// `eval_step_inner` as a null pointer. Confirm that this is the intended handling of named inputs/labels.
///
/// # Errors
/// Returns an error if the underlying `EvalStep` call fails.
pub fn eval_step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>(
&'s self,
inputs: impl Into<SessionInputs<'i1, 'v1, N1>>,
labels: impl Into<SessionInputs<'i2, 'v2, N2>>
) -> Result<SessionOutputs<'s>> {
match inputs.into() {
// Positional inputs given as a slice.
SessionInputs::ValueSlice(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels).map(Some), None),
SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()).map(Some), None),
SessionInputs::ValueMap(labels) => {
// Named labels are looked up against *all* evaluation input names, not only the trailing ones.
let labels = mapped_inputs(&self.eval_input_names, &labels);
self.eval_step_inner(input_values.iter().map(Some).chain(labels), None)
}
},
// Positional inputs given as a fixed-size array.
SessionInputs::ValueArray(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels).map(Some), None),
SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()).map(Some), None),
SessionInputs::ValueMap(labels) => {
let labels = mapped_inputs(&self.eval_input_names, &labels);
self.eval_step_inner(input_values.iter().map(Some).chain(labels), None)
}
},
// Named inputs: ordered by the evaluation model's input names before any labels are appended.
SessionInputs::ValueMap(input_values) => {
let input_values = mapped_inputs(&self.eval_input_names, &input_values);
match labels.into() {
SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.into_iter().chain(labels.iter().map(Some)), None),
SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.into_iter().chain(labels.iter().map(Some)), None),
SessionInputs::ValueMap(labels) => {
let labels = mapped_inputs(&self.eval_input_names, &labels);
self.eval_step_inner(input_values.into_iter().chain(labels), None)
}
}
}
}
}
/// Shared implementation of [`Trainer::eval_step`]: calls `EvalStep` with the given inputs and wraps the
/// results.
///
/// Every `None` in `input_values` is passed to the native call as a null `OrtValue` pointer. One output slot
/// is reserved per evaluation-model output name, and the returned values are paired with those names.
///
/// # Panics
/// Panics if the native call succeeds but leaves any output pointer null.
///
/// # Errors
/// Returns an error if `EvalStep` reports a failure.
fn eval_step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>(
    &'s self,
    input_values: impl Iterator<Item = Option<&'i1 SessionInputValue<'v1>>>,
    run_options: Option<&'r RunOptions>
) -> Result<SessionOutputs<'r>> {
    let feeds: Vec<*const ort_sys::OrtValue> = input_values
        .map(|input| match input {
            Some(value) => value.ptr(),
            None => ptr::null()
        })
        .collect();
    let mut fetches: Vec<*mut ort_sys::OrtValue> = vec![ptr::null_mut(); self.eval_output_names.len()];
    let options_ptr = run_options.map_or(ptr::null(), |options| options.ptr());

    ortsys![@training:
        unsafe EvalStep(
            self.ptr.as_ptr(),
            options_ptr,
            feeds.len(),
            feeds.as_ptr(),
            fetches.len(),
            fetches.as_mut_ptr()
        )?
    ];

    let values = fetches
        .into_iter()
        .map(|raw| {
            let value_ptr = NonNull::new(raw).expect("OrtValue ptr returned from session Run should not be null");
            unsafe { Value::from_ptr(value_ptr, None) }
        })
        .collect();
    let names = self.eval_output_names.iter().map(String::as_str).collect();
    Ok(SessionOutputs::new(names, values))
}
/// Exports a model for inferencing to `out_path`, keeping the graph outputs listed in `output_names`.
///
/// The names are converted to an array of C string pointers for the duration of the
/// `ExportModelForInferencing` call.
///
/// # Errors
/// Returns an error if converting the names fails or if the export call reports a failure.
pub fn export<O: AsRef<str>>(&self, out_path: impl AsRef<Path>, output_names: impl AsRef<[O]>) -> Result<()> {
    let destination = crate::util::path_to_os_char(out_path);
    let requested_outputs = output_names.as_ref();
    with_cstr_ptr_array(requested_outputs, &|name_ptrs| {
        ortsys![@training:
            unsafe ExportModelForInferencing(
                self.ptr.as_ptr(),
                destination.as_ptr(),
                name_ptrs.len(),
                name_ptrs.as_ptr()
            )?
        ];
        Ok(())
    })
}
/// Returns the parameter size reported by `GetParametersSize` for this session.
///
/// `trainable_only` is forwarded to the native call (presumably restricting the count to trainable
/// parameters — confirm against the ONNX Runtime training API documentation).
///
/// # Errors
/// Returns an error if the native call reports a failure.
pub fn num_params(&self, trainable_only: bool) -> Result<usize> {
    let mut param_count: usize = 0;
    ortsys![@training:
        unsafe GetParametersSize(
            self.ptr.as_ptr(),
            &mut param_count,
            trainable_only
        )?
    ];
    Ok(param_count)
}
/// Copies the session's parameters into `value` via `CopyParametersToBuffer`.
///
/// `trainable_only` is forwarded to the native call unchanged. The tensor is presumably expected to be large
/// enough to hold the parameters (see [`Trainer::num_params`]) — confirm against the native API contract.
///
/// # Errors
/// Returns an error if the native call reports a failure.
pub fn copy_parameters_to<T: IntoTensorElementType + fmt::Debug>(&self, value: &mut Tensor<T>, trainable_only: bool) -> Result<()> {
    let destination = value.ptr_mut();
    ortsys![@training:
        unsafe CopyParametersToBuffer(
            self.ptr.as_ptr(),
            destination,
            trainable_only
        )?
    ];
    Ok(())
}
/// Overwrites the session's parameters with the contents of `value` via `CopyBufferToParameters`.
///
/// `trainable_only` is forwarded to the native call unchanged. The native function takes a mutable pointer,
/// so the tensor's const pointer is cast with `cast_mut` before the call.
///
/// # Errors
/// Returns an error if the native call reports a failure.
pub fn copy_parameters_from<T: IntoTensorElementType + fmt::Debug>(&mut self, value: &Tensor<T>, trainable_only: bool) -> Result<()> {
    let source = value.ptr().cast_mut();
    ortsys![@training:
        unsafe CopyBufferToParameters(
            self.ptr.as_ptr(),
            source,
            trainable_only
        )?
    ];
    Ok(())
}
/// Returns an [`Optimizer`] constructed from this trainer's session pointer.
///
/// The returned optimizer borrows from `self` for its lifetime.
pub fn optimizer(&self) -> Optimizer<'_> {
Optimizer::new(self.ptr)
}
/// Returns a reference to the [`Checkpoint`] this trainer was created with.
pub fn checkpoint(&self) -> &Checkpoint {
&self.ckpt
}
}
/// Exposes the raw `OrtTrainingSession` pointer held by a [`Trainer`].
impl AsPointer for Trainer {
type Sys = ort_sys::OrtTrainingSession;
/// Returns the session pointer as a const raw pointer; ownership stays with the trainer.
fn ptr(&self) -> *const Self::Sys {
self.ptr.as_ptr()
}
}
/// Releases the underlying training session when the [`Trainer`] goes out of scope.
impl Drop for Trainer {
fn drop(&mut self) {
// Log the drop first, then hand the session back to the training API. The remaining fields (checkpoint,
// allocator, environment) are dropped afterwards by the compiler-generated field drops.
crate::logging::drop!(Trainer, self.ptr);
ortsys![@training: unsafe ReleaseTrainingSession(self.ptr.as_ptr())];
}
}
/// Arranges named values in the order given by `input_names`.
///
/// The result has exactly one entry per name in `input_names`: a reference to the first value in `values`
/// whose name equals it, or `None` when no value carries that name. Values whose names do not appear in
/// `input_names` are ignored.
fn mapped_inputs<'v, 'a>(input_names: &[String], values: &'a [(Cow<'_, str>, SessionInputValue<'v>)]) -> Vec<Option<&'a SessionInputValue<'v>>> {
    input_names
        .iter()
        .map(|wanted| {
            values
                .iter()
                // `find` stops at the first match, so duplicates later in `values` are never selected.
                .find(|(name, _)| wanted == name)
                .map(|(_, value)| value)
        })
        .collect()
}
fn extract_io_names(
ptr: NonNull<ort_sys::OrtTrainingSession>,
allocator: &Allocator,
get_count: unsafe extern "system" fn(sess: *const ort_sys::OrtTrainingSession, out: *mut usize) -> ort_sys::OrtStatusPtr,
get_name: unsafe extern "system" fn(
sess: *const ort_sys::OrtTrainingSession,
index: usize,
allocator: *mut ort_sys::OrtAllocator,
output: *mut *const c_char
) -> ort_sys::OrtStatusPtr
) -> Result<Vec<String>> {
let mut count = 0;
unsafe { Error::result_from_status(get_count(ptr.as_ptr(), &mut count)) }?;
(0..count)
.map(|i| {
let mut name_bytes: *const c_char = ptr::null();
unsafe { Error::result_from_status(get_name(ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_bytes)) }?;
let name = match char_p_to_string(name_bytes) {
Ok(name) => name,
Err(e) => {
unsafe { allocator.free(name_bytes.cast_mut()) };
return Err(e);
}
};
unsafe { allocator.free(name_bytes.cast_mut()) };
Ok(name)
})
.collect::<Result<Vec<String>>>()
}