use alloc::sync::Arc;
use core::ptr::{self, NonNull};
#[cfg(feature = "std")]
use std::path::Path;
use crate::{AsPointer, Result, memory::Allocator, ortsys};
/// Crate-internal owner of a raw `OrtLoraAdapter` handle.
///
/// [`Adapter`] keeps this type behind an `Arc`, so every clone of an `Adapter` refers to the same
/// `AdapterInner`. The handle is released by this type's `Drop` implementation.
#[derive(Debug)]
pub(crate) struct AdapterInner {
// Non-null pointer to the underlying `OrtLoraAdapter` object.
ptr: NonNull<ort_sys::OrtLoraAdapter>
}
impl AsPointer for AdapterInner {
    type Sys = ort_sys::OrtLoraAdapter;

    /// Returns the stored adapter handle as a raw `*const` pointer.
    fn ptr(&self) -> *const Self::Sys {
        // `NonNull::as_ptr` yields a `*mut`; spell out the constness change instead of relying
        // on the implicit coercion at the return site.
        self.ptr.as_ptr().cast_const()
    }
}
impl Drop for AdapterInner {
    /// Releases the adapter handle through `ReleaseLoraAdapter`, then reports the drop via the
    /// crate's logging macro.
    fn drop(&mut self) {
        let handle = self.ptr.as_ptr();
        // The handle was stored as `NonNull` at construction, so a non-null pointer is passed here.
        ortsys![unsafe ReleaseLoraAdapter(handle)];
        crate::logging::drop!(Adapter, self.ptr);
    }
}
/// A handle to an `OrtLoraAdapter`, created with [`Adapter::from_file`] or
/// [`Adapter::from_memory`].
///
/// The underlying handle lives in an `Arc`, so cloning an `Adapter` yields another reference to
/// the same adapter object; the handle is released once the last clone is dropped.
#[cfg_attr(docsrs, doc(cfg(feature = "api-20")))]
#[derive(Debug, Clone)]
pub struct Adapter {
// Shared owner of the raw adapter handle.
pub(crate) inner: Arc<AdapterInner>
}
impl Adapter {
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub fn from_file(path: impl AsRef<Path>, allocator: Option<&Allocator>) -> Result<Self> {
let path = crate::util::path_to_os_char(path);
let allocator_ptr = allocator.map(|c| c.ptr().cast_mut()).unwrap_or_else(ptr::null_mut);
let mut ptr = ptr::null_mut();
ortsys![unsafe CreateLoraAdapter(path.as_ptr(), allocator_ptr, &mut ptr)?; nonNull(ptr)];
crate::logging::create!(Adapter, ptr);
Ok(Adapter {
inner: Arc::new(AdapterInner { ptr })
})
}
pub fn from_memory(bytes: &[u8], allocator: Option<&Allocator>) -> Result<Self> {
let allocator_ptr = allocator.map(|c| c.ptr().cast_mut()).unwrap_or_else(ptr::null_mut);
let mut ptr = ptr::null_mut();
ortsys![unsafe CreateLoraAdapterFromArray(bytes.as_ptr().cast(), bytes.len(), allocator_ptr, &mut ptr)?; nonNull(ptr)];
crate::logging::create!(Adapter, ptr);
Ok(Adapter {
inner: Arc::new(AdapterInner { ptr })
})
}
}
impl AsPointer for Adapter {
type Sys = ort_sys::OrtLoraAdapter;
fn ptr(&self) -> *const Self::Sys {
self.inner.ptr()
}
}
#[cfg(test)]
mod tests {
    use super::Adapter;
    use crate::{
        session::{RunOptions, Session},
        value::Tensor
    };

    /// Shared body of the adapter tests.
    ///
    /// Commits a session from `tests/data/lora_model.onnx`, obtains an adapter from
    /// `load_adapter`, runs the session on a 4x4 `f32` tensor of ones with the adapter added to
    /// the run options, and checks the first four values of the `"output"` tensor.
    ///
    /// The adapter is requested only after the session has been committed.
    fn check_adapter_output(load_adapter: impl FnOnce() -> crate::Result<Adapter>) -> crate::Result<()> {
        let model = std::fs::read("tests/data/lora_model.onnx").expect("");
        let mut session = Session::builder()?.commit_from_memory(&model)?;
        let lora = load_adapter()?;

        let mut run_options = RunOptions::new()?;
        run_options.add_adapter(&lora)?;

        let output: Tensor<f32> = session
            .run_with_options(crate::inputs![Tensor::<f32>::from_array(([4, 4], vec![1.0; 16]))?], &run_options)?
            .remove("output")
            .expect("")
            .downcast()?;
        let (_, values) = output.extract_tensor();

        for (idx, expected) in [154.0, 176.0, 198.0, 220.0].iter().enumerate() {
            assert_eq!(values[idx], *expected);
        }
        Ok(())
    }

    /// Runs the model with an adapter loaded through `Adapter::from_file`.
    #[test]
    #[cfg(feature = "std")]
    fn test_lora() -> crate::Result<()> {
        check_adapter_output(|| Adapter::from_file("tests/data/adapter.orl", None))
    }

    /// Runs the model with an adapter loaded through `Adapter::from_memory`.
    #[test]
    fn test_lora_from_memory() -> crate::Result<()> {
        check_adapter_output(|| {
            let lora_bytes = std::fs::read("tests/data/adapter.orl").expect("");
            let lora = Adapter::from_memory(&lora_bytes, None)?;
            // Free the source buffer before the adapter is used in the run.
            drop(lora_bytes);
            Ok(lora)
        })
    }
}