use crate::element::TensorElement;
use crate::environment::Environment;
use crate::io_binding::IoBinding;
use crate::memory::{MemoryInfo, MemoryInfoSnapshot};
use crate::prepacked::PrepackedWeightsContainer;
use crate::session::{IoDirection, LaneBufferPolicy, Session, lane_tensor_buffer};
use crate::session_options::SessionOptions;
use crate::tensor::TensorBuffer;
use crate::{Error, Result};
use std::ffi::CString;
use std::sync::Arc;
/// Snapshot of a single tensor buffer, as produced by `audit_tensor_buffer`.
///
/// Pointers are stored as plain integers so the audit can be compared, cloned
/// and printed without holding a borrow on the buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorBufferAudit {
/// Whether the buffer was audited as an input or an output.
pub direction: IoDirection,
/// Position of the buffer within its lane's input or output list.
pub index: usize,
/// Element type tag, taken from `T::ELEM` of the buffer's element type.
pub element_type: crate::ElementType,
/// Number of elements, as reported by `TensorBuffer::len`.
pub element_count: usize,
/// Size in bytes, as reported by `TensorBuffer::byte_len`.
pub byte_len: usize,
/// Address of the first element of `TensorBuffer::as_slice`.
pub rust_ptr: usize,
/// Address returned by `TensorBuffer::engine_data_ptr`.
pub ort_ptr: usize,
/// `true` when `rust_ptr == ort_ptr`.
pub pointer_identity: bool,
/// Memory descriptor snapshot returned by `TensorBuffer::memory_info`.
pub memory_info: MemoryInfoSnapshot,
}
/// Per-lane audit report returned by the `audit_hot_path` methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LaneHotPathAudit {
/// Number of input buffers owned by the lane.
pub input_count: usize,
/// Number of output buffers owned by the lane.
pub output_count: usize,
/// Whether the lane re-binds its inputs before every run.
pub rebind_inputs_each_run: bool,
/// Whether the lane reports its input names as cached.
pub input_names_cached: bool,
/// One audit entry per input buffer, in index order.
pub inputs: Vec<TensorBufferAudit>,
/// One audit entry per output buffer, in index order.
pub outputs: Vec<TensorBufferAudit>,
}
/// Collects a [`TensorBufferAudit`] for `buffer`.
///
/// `direction` and `index` are copied into the report unchanged. The report
/// records the address of the buffer's Rust slice next to the address returned
/// by `engine_data_ptr`, and flags whether the two are equal.
///
/// # Errors
/// Propagates any error from `engine_data_ptr`, `byte_len` or `memory_info`.
fn audit_tensor_buffer<T: TensorElement>(
    direction: IoDirection, index: usize, buffer: &TensorBuffer<T>,
) -> Result<TensorBufferAudit> {
    let host_addr = buffer.as_slice().as_ptr() as usize;
    let engine_addr = buffer.engine_data_ptr()? as usize;
    let element_count = buffer.len();
    let byte_len = buffer.byte_len()?;
    let memory_info = buffer.memory_info()?;

    let audit = TensorBufferAudit {
        direction,
        index,
        element_type: T::ELEM,
        element_count,
        byte_len,
        rust_ptr: host_addr,
        ort_ptr: engine_addr,
        pointer_identity: host_addr == engine_addr,
        memory_info,
    };
    Ok(audit)
}
fn assert_tensor_buffer_zero_copy<T: TensorElement>(
what: &str, index: usize, buffer: &TensorBuffer<T>,
) -> Result<()> {
let audit = audit_tensor_buffer(IoDirection::Input, index, buffer)?;
if !audit.pointer_identity {
return Err(Error::new(
-1,
format!(
"zrt: {what} {index} is not zero-copy: rust_ptr=0x{:x}, ort_ptr=0x{:x}",
audit.rust_ptr, audit.ort_ptr
),
));
}
if !audit.memory_info.is_host_accessible() {
return Err(Error::new(
-1,
format!(
"zrt: {what} {index} is not host-accessible: {} device {} ({:?}/{:?})",
audit.memory_info.name,
audit.memory_info.device_id,
audit.memory_info.alloc_type,
audit.memory_info.mem_type
),
));
}
Ok(())
}
/// How the lanes of a runtime relate to their sessions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuntimeMode {
/// Every lane holds a clone of the same `Arc<Session>`.
SharedSession,
/// Lanes were built from a list of sessions, one session per lane.
ReplicatedSessions,
}
/// A pre-bound execution lane whose inputs and outputs share one element type `T`.
///
/// NOTE(review): Rust drops fields in declaration order, so `binding` is dropped
/// before the buffers it was bound to and before `session` — presumably
/// intentional; confirm before reordering fields.
pub struct Lane<T: TensorElement> {
// I/O binding that the input and output buffers are bound into at construction.
binding: IoBinding,
// Input buffers, indexed like the session's inputs.
inputs: Vec<TensorBuffer<T>>,
// Output buffers, indexed like the session's outputs.
outputs: Vec<TensorBuffer<T>>,
// Session used to run the binding.
session: Arc<Session>,
}
/// A pre-bound execution lane with compile-time input/output counts and separate
/// element types for inputs (`I`) and outputs (`O`).
///
/// NOTE(review): Rust drops fields in declaration order, so `binding` is dropped
/// before the buffers and the session — presumably intentional; confirm before
/// reordering fields.
pub struct StaticIoLane<
I: TensorElement,
O: TensorElement,
const INPUTS: usize,
const OUTPUTS: usize,
> {
// I/O binding that the buffers are bound into.
binding: IoBinding,
// Input buffers, indexed like the session's inputs.
inputs: [TensorBuffer<I>; INPUTS],
// Output buffers, indexed like the session's outputs.
outputs: [TensorBuffer<O>; OUTPUTS],
// Session input names converted to `CString` once at construction, used for re-binding.
input_names: [CString; INPUTS],
// Session used to run the binding.
session: Arc<Session>,
// When true, inputs are cleared and re-bound before every run.
rebind_inputs_each_run: bool,
}
impl<T> Lane<T>
where
T: TensorElement + Clone + Default,
{
pub(crate) fn new(
session: Arc<Session>, mem: &MemoryInfo, input_shapes: &[&[i64]], output_shapes: &[&[i64]],
policy: LaneBufferPolicy,
) -> Result<Self> {
if input_shapes.len() != session.input_count() {
return Err(Error::new(
-1,
format!(
"zrt: input shape count mismatch: expected {}, got {}",
session.input_count(),
input_shapes.len()
),
));
}
if output_shapes.len() != session.output_count() {
return Err(Error::new(
-1,
format!(
"zrt: output shape count mismatch: expected {}, got {}",
session.output_count(),
output_shapes.len()
),
));
}
let inputs: Vec<TensorBuffer<T>> = input_shapes
.iter()
.map(|shape| lane_tensor_buffer(shape, mem, policy))
.collect::<Result<_>>()?;
let outputs: Vec<TensorBuffer<T>> = output_shapes
.iter()
.map(|shape| lane_tensor_buffer(shape, mem, policy))
.collect::<Result<_>>()?;
let mut binding = IoBinding::new(&session)?;
for (i, input) in inputs.iter().enumerate() {
binding.bind_input(session.input_name(i)?, input)?;
}
for (i, output) in outputs.iter().enumerate() {
binding.bind_output_buffer(session.output_name(i)?, output)?;
}
Ok(Self {
binding,
inputs,
outputs,
session,
})
}
}
impl<T: TensorElement> Lane<T> {
#[inline]
pub fn run(&mut self) -> Result<()> {
self.session.run_binding(&self.binding)
}
#[inline]
pub fn run_unsynchronized(&mut self) -> Result<()> {
self.session.run_binding_unsynchronized(&self.binding)
}
pub fn prime(&mut self, runs: usize) -> Result<()> {
for _ in 0..runs {
self.run()?;
}
Ok(())
}
#[inline]
pub fn input(&self, i: usize) -> Result<&[T]> {
self.inputs
.get(i)
.map(TensorBuffer::as_slice)
.ok_or_else(|| Error::new(-1, format!("zrt: lane input index {i} out of range")))
}
#[inline]
pub fn input_mut(&mut self, i: usize) -> Result<&mut [T]> {
self.inputs
.get_mut(i)
.map(TensorBuffer::as_mut_slice)
.ok_or_else(|| Error::new(-1, format!("zrt: lane input index {i} out of range")))
}
#[inline]
pub fn output(&self, i: usize) -> Result<&[T]> {
self.outputs
.get(i)
.map(TensorBuffer::as_slice)
.ok_or_else(|| Error::new(-1, format!("zrt: lane output index {i} out of range")))
}
#[inline]
pub fn output_mut(&mut self, i: usize) -> Result<&mut [T]> {
self.outputs
.get_mut(i)
.map(TensorBuffer::as_mut_slice)
.ok_or_else(|| Error::new(-1, format!("zrt: lane output index {i} out of range")))
}
#[inline]
pub fn input_buffer(&self, i: usize) -> Result<&TensorBuffer<T>> {
self.inputs
.get(i)
.ok_or_else(|| Error::new(-1, format!("zrt: lane input index {i} out of range")))
}
#[inline]
pub fn output_buffer(&self, i: usize) -> Result<&TensorBuffer<T>> {
self.outputs
.get(i)
.ok_or_else(|| Error::new(-1, format!("zrt: lane output index {i} out of range")))
}
#[inline]
pub fn session(&self) -> &Session {
&self.session
}
pub fn audit_hot_path(&self) -> Result<LaneHotPathAudit> {
let inputs = self
.inputs
.iter()
.enumerate()
.map(|(i, buffer)| audit_tensor_buffer(IoDirection::Input, i, buffer))
.collect::<Result<Vec<_>>>()?;
let outputs = self
.outputs
.iter()
.enumerate()
.map(|(i, buffer)| audit_tensor_buffer(IoDirection::Output, i, buffer))
.collect::<Result<Vec<_>>>()?;
Ok(LaneHotPathAudit {
input_count: self.inputs.len(),
output_count: self.outputs.len(),
rebind_inputs_each_run: false,
input_names_cached: true,
inputs,
outputs,
})
}
pub fn assert_zero_copy_plan(&self) -> Result<()> {
for (i, input) in self.inputs.iter().enumerate() {
assert_tensor_buffer_zero_copy("lane input", i, input)?;
}
for (i, output) in self.outputs.iter().enumerate() {
assert_tensor_buffer_zero_copy("lane output", i, output)?;
}
Ok(())
}
}
impl<I, O, const INPUTS: usize, const OUTPUTS: usize> StaticIoLane<I, O, INPUTS, OUTPUTS>
where
I: TensorElement + Clone + Default,
O: TensorElement + Clone + Default,
{
pub fn new(
session: Arc<Session>, mem: &MemoryInfo, input_shapes: [&[i64]; INPUTS],
output_shapes: [&[i64]; OUTPUTS],
) -> Result<Self> {
Self::with_buffer_policy(
session,
mem,
mem,
input_shapes,
output_shapes,
LaneBufferPolicy::Auto,
LaneBufferPolicy::Auto,
)
}
pub fn with_memory(
session: Arc<Session>, input_mem: &MemoryInfo, output_mem: &MemoryInfo,
input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS],
) -> Result<Self> {
Self::with_buffer_policy(
session,
input_mem,
output_mem,
input_shapes,
output_shapes,
LaneBufferPolicy::Auto,
LaneBufferPolicy::Auto,
)
}
pub fn with_buffer_policy(
session: Arc<Session>, input_mem: &MemoryInfo, output_mem: &MemoryInfo,
input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS],
input_policy: LaneBufferPolicy, output_policy: LaneBufferPolicy,
) -> Result<Self> {
if INPUTS != session.input_count() {
return Err(Error::new(
-1,
format!(
"zrt: static I/O lane input count mismatch: expected {}, got {}",
session.input_count(),
INPUTS
),
));
}
if OUTPUTS != session.output_count() {
return Err(Error::new(
-1,
format!(
"zrt: static I/O lane output count mismatch: expected {}, got {}",
session.output_count(),
OUTPUTS
),
));
}
let inputs: [TensorBuffer<I>; INPUTS] = input_shapes
.iter()
.map(|shape| lane_tensor_buffer(shape, input_mem, input_policy))
.collect::<Result<Vec<_>>>()?
.try_into()
.map_err(|_| Error::new(-1, "zrt: failed to build static I/O input array"))?;
let outputs: [TensorBuffer<O>; OUTPUTS] = output_shapes
.iter()
.map(|shape| lane_tensor_buffer(shape, output_mem, output_policy))
.collect::<Result<Vec<_>>>()?
.try_into()
.map_err(|_| Error::new(-1, "zrt: failed to build static I/O output array"))?;
let mut binding = IoBinding::new(&session)?;
let input_names: [CString; INPUTS] = (0..INPUTS)
.map(|i| {
CString::new(session.input_name(i)?)
.map_err(|_| Error::new(-1, "zrt: input name contains a NUL"))
})
.collect::<Result<Vec<_>>>()?
.try_into()
.map_err(|_| Error::new(-1, "zrt: failed to build static I/O input name array"))?;
for (i, input) in inputs.iter().enumerate() {
binding.bind_input_cstr(&input_names[i], input)?;
}
for (i, output) in outputs.iter().enumerate() {
binding.bind_output_buffer(session.output_name(i)?, output)?;
}
Ok(Self {
binding,
inputs,
outputs,
input_names,
session,
rebind_inputs_each_run: false,
})
}
}
impl<I: TensorElement, O: TensorElement, const INPUTS: usize, const OUTPUTS: usize>
StaticIoLane<I, O, INPUTS, OUTPUTS>
{
#[inline]
pub fn run(&mut self) -> Result<()> {
if self.rebind_inputs_each_run {
self.binding.clear_inputs();
for (i, input) in self.inputs.iter().enumerate() {
self.binding.bind_input_cstr(&self.input_names[i], input)?;
}
}
self.session.run_binding(&self.binding)
}
#[inline]
pub fn run_unsynchronized(&mut self) -> Result<()> {
if self.rebind_inputs_each_run {
self.binding.clear_inputs();
for (i, input) in self.inputs.iter().enumerate() {
self.binding.bind_input_cstr(&self.input_names[i], input)?;
}
}
self.session.run_binding_unsynchronized(&self.binding)
}
pub fn prime(&mut self, runs: usize) -> Result<()> {
for _ in 0..runs {
self.run()?;
}
Ok(())
}
#[inline]
pub fn inputs(&self) -> &[TensorBuffer<I>; INPUTS] {
&self.inputs
}
#[inline]
pub fn inputs_mut(&mut self) -> &mut [TensorBuffer<I>; INPUTS] {
&mut self.inputs
}
#[inline]
pub fn outputs(&self) -> &[TensorBuffer<O>; OUTPUTS] {
&self.outputs
}
#[inline]
pub fn outputs_mut(&mut self) -> &mut [TensorBuffer<O>; OUTPUTS] {
&mut self.outputs
}
#[inline]
pub fn input(&self, i: usize) -> Result<&[I]> {
self.inputs
.get(i)
.map(TensorBuffer::as_slice)
.ok_or_else(|| {
Error::new(
-1,
format!("zrt: static I/O lane input index {i} out of range"),
)
})
}
#[inline]
pub fn input_mut(&mut self, i: usize) -> Result<&mut [I]> {
self.inputs
.get_mut(i)
.map(TensorBuffer::as_mut_slice)
.ok_or_else(|| {
Error::new(
-1,
format!("zrt: static I/O lane input index {i} out of range"),
)
})
}
#[inline]
pub fn output(&self, i: usize) -> Result<&[O]> {
self.outputs
.get(i)
.map(TensorBuffer::as_slice)
.ok_or_else(|| {
Error::new(
-1,
format!("zrt: static I/O lane output index {i} out of range"),
)
})
}
#[inline]
pub fn output_mut(&mut self, i: usize) -> Result<&mut [O]> {
self.outputs
.get_mut(i)
.map(TensorBuffer::as_mut_slice)
.ok_or_else(|| {
Error::new(
-1,
format!("zrt: static I/O lane output index {i} out of range"),
)
})
}
#[inline]
pub fn input_at<const IDX: usize>(&self) -> Result<&[I]> {
self.input(IDX)
}
#[inline]
pub fn input_mut_at<const IDX: usize>(&mut self) -> Result<&mut [I]> {
self.input_mut(IDX)
}
#[inline]
pub fn output_at<const IDX: usize>(&self) -> Result<&[O]> {
self.output(IDX)
}
#[inline]
pub fn output_mut_at<const IDX: usize>(&mut self) -> Result<&mut [O]> {
self.output_mut(IDX)
}
#[inline]
pub fn input_buffer(&self, i: usize) -> Result<&TensorBuffer<I>> {
self.inputs.get(i).ok_or_else(|| {
Error::new(
-1,
format!("zrt: static I/O lane input index {i} out of range"),
)
})
}
#[inline]
pub fn output_buffer(&self, i: usize) -> Result<&TensorBuffer<O>> {
self.outputs.get(i).ok_or_else(|| {
Error::new(
-1,
format!("zrt: static I/O lane output index {i} out of range"),
)
})
}
#[inline]
pub fn session(&self) -> &Session {
&self.session
}
pub fn audit_hot_path(&self) -> Result<LaneHotPathAudit> {
let inputs = self
.inputs
.iter()
.enumerate()
.map(|(i, buffer)| audit_tensor_buffer(IoDirection::Input, i, buffer))
.collect::<Result<Vec<_>>>()?;
let outputs = self
.outputs
.iter()
.enumerate()
.map(|(i, buffer)| audit_tensor_buffer(IoDirection::Output, i, buffer))
.collect::<Result<Vec<_>>>()?;
Ok(LaneHotPathAudit {
input_count: INPUTS,
output_count: OUTPUTS,
rebind_inputs_each_run: self.rebind_inputs_each_run,
input_names_cached: self.input_names.len() == INPUTS,
inputs,
outputs,
})
}
pub fn assert_zero_copy_plan(&self) -> Result<()> {
for (i, input) in self.inputs.iter().enumerate() {
assert_tensor_buffer_zero_copy("static I/O lane input", i, input)?;
}
for (i, output) in self.outputs.iter().enumerate() {
assert_tensor_buffer_zero_copy("static I/O lane output", i, output)?;
}
Ok(())
}
#[inline]
pub fn set_rebind_inputs_each_run(&mut self, enabled: bool) {
self.rebind_inputs_each_run = enabled;
}
}
/// Builds `lanes` lanes that all hold a clone of the same `session`.
///
/// `what` names the caller and is used as the prefix of the error message when
/// `lanes` is zero.
///
/// # Errors
/// Returns an error with code `-1` when `lanes == 0`; otherwise propagates the
/// first error from `Lane::new`.
fn build_shared_lanes<T>(
    session: Arc<Session>, mem: &MemoryInfo, input_shapes: &[&[i64]], output_shapes: &[&[i64]],
    lanes: usize, policy: LaneBufferPolicy, what: &'static str,
) -> Result<Vec<Lane<T>>>
where
    T: TensorElement + Clone + Default,
{
    if lanes == 0 {
        return Err(Error::new(-1, format!("{what} requires at least one lane")));
    }
    let mut built = Vec::new();
    for _ in 0..lanes {
        let lane = Lane::new(Arc::clone(&session), mem, input_shapes, output_shapes, policy)?;
        built.push(lane);
    }
    Ok(built)
}
/// A set of [`Lane`]s sharing one element type, plus the mode they were built in.
pub struct Runtime<T: TensorElement> {
// Lanes in construction order.
lanes: Vec<Lane<T>>,
// Records whether the lanes share a session or were built from a session list.
mode: RuntimeMode,
}
/// A set of [`StaticIoLane`]s with identical type parameters, plus the mode they
/// were built in.
pub struct StaticIoRuntime<
I: TensorElement,
O: TensorElement,
const INPUTS: usize,
const OUTPUTS: usize,
> {
// Lanes in construction order.
lanes: Vec<StaticIoLane<I, O, INPUTS, OUTPUTS>>,
// Records whether the lanes share a session or were built from a session list.
mode: RuntimeMode,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DynamicIoOptions {
pub max_buckets: usize,
pub input_policy: LaneBufferPolicy,
pub output_policy: LaneBufferPolicy,
pub rebind_inputs_each_run: bool,
}
impl DynamicIoOptions {
#[inline]
pub fn new(max_buckets: usize) -> Self {
Self {
max_buckets,
..Self::default()
}
}
#[inline]
pub fn with_input_policy(mut self, policy: LaneBufferPolicy) -> Self {
self.input_policy = policy;
self
}
#[inline]
pub fn with_output_policy(mut self, policy: LaneBufferPolicy) -> Self {
self.output_policy = policy;
self
}
#[inline]
pub fn with_rebind_inputs_each_run(mut self, enabled: bool) -> Self {
self.rebind_inputs_each_run = enabled;
self
}
fn validate(self) -> Result<Self> {
if self.max_buckets == 0 {
return Err(Error::new(
-1,
"DynamicIoRuntime requires at least one shape bucket",
));
}
Ok(self)
}
}
impl Default for DynamicIoOptions {
#[inline]
fn default() -> Self {
Self {
max_buckets: 16,
input_policy: LaneBufferPolicy::Auto,
output_policy: LaneBufferPolicy::Auto,
rebind_inputs_each_run: false,
}
}
}
/// Borrowed description of one shape combination: one dimension slice per
/// input and one per output.
#[derive(Debug, Clone, Copy)]
pub struct ShapeSpec<'a, const INPUTS: usize, const OUTPUTS: usize> {
    /// Dimension slices for the inputs, in input order.
    pub input_shapes: [&'a [i64]; INPUTS],
    /// Dimension slices for the outputs, in output order.
    pub output_shapes: [&'a [i64]; OUTPUTS],
}

impl<'a, const INPUTS: usize, const OUTPUTS: usize> ShapeSpec<'a, INPUTS, OUTPUTS> {
    /// Bundles the given input and output shape arrays into a `ShapeSpec`.
    #[inline]
    pub fn new(input_shapes: [&'a [i64]; INPUTS], output_shapes: [&'a [i64]; OUTPUTS]) -> Self {
        ShapeSpec {
            input_shapes,
            output_shapes,
        }
    }
}
/// Owned copy of one shape combination, usable as a lookup key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ShapeKey<const INPUTS: usize, const OUTPUTS: usize> {
    input_shapes: [Vec<i64>; INPUTS],
    output_shapes: [Vec<i64>; OUTPUTS],
}

impl<const INPUTS: usize, const OUTPUTS: usize> ShapeKey<INPUTS, OUTPUTS> {
    /// Copies every borrowed shape slice into an owned `Vec<i64>`.
    pub fn new(input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS]) -> Self {
        let owned_inputs = input_shapes.map(|dims| dims.to_vec());
        let owned_outputs = output_shapes.map(|dims| dims.to_vec());
        Self {
            input_shapes: owned_inputs,
            output_shapes: owned_outputs,
        }
    }

    /// Returns the shape of input `i`, or `None` when `i` is out of range.
    #[inline]
    pub fn input_shape(&self, i: usize) -> Option<&[i64]> {
        let dims = self.input_shapes.get(i)?;
        Some(dims.as_slice())
    }

    /// Returns the shape of output `i`, or `None` when `i` is out of range.
    #[inline]
    pub fn output_shape(&self, i: usize) -> Option<&[i64]> {
        let dims = self.output_shapes.get(i)?;
        Some(dims.as_slice())
    }

    /// Borrows all stored input shapes.
    #[inline]
    pub fn input_shapes(&self) -> &[Vec<i64>; INPUTS] {
        &self.input_shapes
    }

    /// Borrows all stored output shapes.
    #[inline]
    pub fn output_shapes(&self) -> &[Vec<i64>; OUTPUTS] {
        &self.output_shapes
    }

    /// Returns `true` when every stored input and output shape equals the
    /// corresponding borrowed shape. Inputs are compared before outputs.
    #[inline]
    fn matches(&self, input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS]) -> bool {
        for (stored, candidate) in self.input_shapes.iter().zip(input_shapes.iter()) {
            if stored.as_slice() != *candidate {
                return false;
            }
        }
        for (stored, candidate) in self.output_shapes.iter().zip(output_shapes.iter()) {
            if stored.as_slice() != *candidate {
                return false;
            }
        }
        true
    }
}
/// One cached lane set together with the shape key it was built for.
pub struct ShapeBucket<
I: TensorElement,
O: TensorElement,
const INPUTS: usize,
const OUTPUTS: usize,
> {
// Shapes this bucket's lanes were allocated for.
key: ShapeKey<INPUTS, OUTPUTS>,
// Lane set built for `key`.
lanes: StaticIoRuntime<I, O, INPUTS, OUTPUTS>,
// Tick value recorded when the bucket was last handed out; the bucket with
// the smallest value is the one removed when the owning runtime is full.
last_used: u64,
}
impl<I: TensorElement, O: TensorElement, const INPUTS: usize, const OUTPUTS: usize>
ShapeBucket<I, O, INPUTS, OUTPUTS>
{
/// Borrows the shape key this bucket was built for.
#[inline]
pub fn key(&self) -> &ShapeKey<INPUTS, OUTPUTS> {
&self.key
}
/// Returns the tick recorded when the bucket was last used.
#[inline]
pub fn last_used(&self) -> u64 {
self.last_used
}
/// Borrows the bucket's lane set.
#[inline]
pub fn lanes(&self) -> &StaticIoRuntime<I, O, INPUTS, OUTPUTS> {
&self.lanes
}
/// Mutably borrows the bucket's lane set.
#[inline]
pub fn lanes_mut(&mut self) -> &mut StaticIoRuntime<I, O, INPUTS, OUTPUTS> {
&mut self.lanes
}
/// Borrows lane `i`; delegates to `StaticIoRuntime::lane`.
#[inline]
pub fn lane(&self, i: usize) -> Result<&StaticIoLane<I, O, INPUTS, OUTPUTS>> {
self.lanes.lane(i)
}
/// Mutably borrows lane `i`; delegates to `StaticIoRuntime::lane_mut`.
#[inline]
pub fn lane_mut(&mut self, i: usize) -> Result<&mut StaticIoLane<I, O, INPUTS, OUTPUTS>> {
self.lanes.lane_mut(i)
}
/// Calls `f` with lane `i`; delegates to `StaticIoRuntime::run_on`.
#[inline]
pub fn run_on<R>(
&mut self, i: usize, f: impl FnOnce(&mut StaticIoLane<I, O, INPUTS, OUTPUTS>) -> Result<R>,
) -> Result<R> {
self.lanes.run_on(i, f)
}
/// Primes the lane set; delegates to `StaticIoRuntime::prime`.
pub fn prime(&mut self, runs: usize) -> Result<()> {
self.lanes.prime(runs)
}
/// Audits the lane set; delegates to `StaticIoRuntime::audit_hot_path`.
pub fn audit_hot_path(&self) -> Result<Vec<LaneHotPathAudit>> {
self.lanes.audit_hot_path()
}
/// Checks the lane set; delegates to `StaticIoRuntime::assert_zero_copy_plan`.
pub fn assert_zero_copy_plan(&self) -> Result<()> {
self.lanes.assert_zero_copy_plan()
}
}
/// Session source used by `DynamicIoRuntime` when it builds a new lane set.
enum DynamicSessions {
// A single session; every lane receives a clone of this `Arc`.
Shared(Arc<Session>),
// A list of sessions; one lane is built per entry.
Replicated(Vec<Arc<Session>>),
}
/// Runtime that keeps a bounded list of [`ShapeBucket`]s, one per distinct
/// shape combination.
pub struct DynamicIoRuntime<
I: TensorElement,
O: TensorElement,
const INPUTS: usize,
const OUTPUTS: usize,
> {
// Session(s) that new lane sets are built from.
sessions: DynamicSessions,
// Memory descriptor passed for input buffers of new lane sets.
input_mem: MemoryInfo,
// Memory descriptor passed for output buffers of new lane sets.
output_mem: MemoryInfo,
// Validated options (bucket limit, buffer policies, re-bind flag).
options: DynamicIoOptions,
// Lanes per bucket in shared-session mode; equals the session count in replicated mode.
lane_count: usize,
// Buckets currently alive; order is not stable because eviction uses `swap_remove`.
buckets: Vec<ShapeBucket<I, O, INPUTS, OUTPUTS>>,
// Index of the most recently returned bucket, checked first on lookup.
hot_bucket: Option<usize>,
// Counter used to stamp `ShapeBucket::last_used`.
tick: u64,
}
impl<T> Runtime<T>
where
T: TensorElement + Clone + Default,
{
/// Builds `lanes` lanes sharing `session`; forwards to
/// [`Runtime::from_shared_session`].
pub fn shared_session(
session: Arc<Session>, mem: &MemoryInfo, input_shapes: &[&[i64]], output_shapes: &[&[i64]],
lanes: usize,
) -> Result<Self> {
Self::from_shared_session(session, mem, input_shapes, output_shapes, lanes)
}
/// Same as [`Runtime::shared_session`] with an explicit buffer policy; forwards
/// to [`Runtime::from_shared_session_with_buffer_policy`].
pub fn shared_session_with_buffer_policy(
session: Arc<Session>, mem: &MemoryInfo, input_shapes: &[&[i64]], output_shapes: &[&[i64]],
lanes: usize, policy: LaneBufferPolicy,
) -> Result<Self> {
Self::from_shared_session_with_buffer_policy(
session,
mem,
input_shapes,
output_shapes,
lanes,
policy,
)
}
/// Builds `lanes` lanes sharing `session`, using `LaneBufferPolicy::Auto`.
pub fn from_shared_session(
session: Arc<Session>, mem: &MemoryInfo, input_shapes: &[&[i64]], output_shapes: &[&[i64]],
lanes: usize,
) -> Result<Self> {
Self::from_shared_session_with_buffer_policy(
session,
mem,
input_shapes,
output_shapes,
lanes,
LaneBufferPolicy::Auto,
)
}
/// Builds `lanes` lanes sharing `session` with the given buffer policy and
/// marks the runtime as `RuntimeMode::SharedSession`.
///
/// # Errors
/// Propagates errors from `build_shared_lanes`, including the error raised
/// when `lanes` is zero.
pub fn from_shared_session_with_buffer_policy(
session: Arc<Session>, mem: &MemoryInfo, input_shapes: &[&[i64]], output_shapes: &[&[i64]],
lanes: usize, policy: LaneBufferPolicy,
) -> Result<Self> {
let lanes = build_shared_lanes(
session,
mem,
input_shapes,
output_shapes,
lanes,
policy,
"Runtime",
)?;
Ok(Self {
lanes,
mode: RuntimeMode::SharedSession,
})
}
/// Builds one lane per session, using `LaneBufferPolicy::Auto`.
pub fn from_sessions(
sessions: Vec<Session>, mem: &MemoryInfo, input_shapes: &[&[i64]], output_shapes: &[&[i64]],
) -> Result<Self> {
Self::from_sessions_with_buffer_policy(
sessions,
mem,
input_shapes,
output_shapes,
LaneBufferPolicy::Auto,
)
}
/// Builds one lane per session with the given buffer policy and marks the
/// runtime as `RuntimeMode::ReplicatedSessions`.
///
/// # Errors
/// Returns an error with code `-1` when `sessions` is empty; otherwise
/// propagates the first error from `Lane::new`.
pub fn from_sessions_with_buffer_policy(
sessions: Vec<Session>, mem: &MemoryInfo, input_shapes: &[&[i64]],
output_shapes: &[&[i64]], policy: LaneBufferPolicy,
) -> Result<Self> {
if sessions.is_empty() {
return Err(Error::new(-1, "Runtime requires at least one session"));
}
// Each session is moved into its own `Arc`, so no two lanes share a session.
let lanes = sessions
.into_iter()
.map(|session| Lane::new(Arc::new(session), mem, input_shapes, output_shapes, policy))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
lanes,
mode: RuntimeMode::ReplicatedSessions,
})
}
/// Creates `lanes` sessions by calling `factory(0..lanes)` and builds one lane
/// per session, using `LaneBufferPolicy::Auto`.
pub fn from_session_factory<F>(
lanes: usize, mem: &MemoryInfo, input_shapes: &[&[i64]], output_shapes: &[&[i64]],
factory: F,
) -> Result<Self>
where
F: FnMut(usize) -> Result<Session>,
{
Self::from_session_factory_with_buffer_policy(
lanes,
mem,
input_shapes,
output_shapes,
LaneBufferPolicy::Auto,
factory,
)
}
/// Creates `lanes` sessions by calling `factory` with each index in
/// `0..lanes`, then builds one lane per session with the given buffer policy.
///
/// # Errors
/// Returns an error with code `-1` when `lanes` is zero; otherwise propagates
/// the first factory error or lane construction error.
pub fn from_session_factory_with_buffer_policy<F>(
lanes: usize, mem: &MemoryInfo, input_shapes: &[&[i64]], output_shapes: &[&[i64]],
policy: LaneBufferPolicy, mut factory: F,
) -> Result<Self>
where
F: FnMut(usize) -> Result<Session>,
{
if lanes == 0 {
return Err(Error::new(-1, "Runtime requires at least one lane"));
}
// All sessions are created before any lane is built.
let sessions = (0..lanes).map(&mut factory).collect::<Result<Vec<_>>>()?;
Self::from_sessions_with_buffer_policy(sessions, mem, input_shapes, output_shapes, policy)
}
/// Creates `lanes` sessions with `Session::new(env, model_path, opts.clone())`
/// and builds one lane per session, using `LaneBufferPolicy::Auto`.
pub fn replicated_sessions(
env: &Environment, model_path: &str, opts: SessionOptions, mem: &MemoryInfo,
input_shapes: &[&[i64]], output_shapes: &[&[i64]], lanes: usize,
) -> Result<Self> {
Self::from_session_factory(lanes, mem, input_shapes, output_shapes, |_| {
Session::new(env, model_path, opts.clone())
})
}
/// Same as [`Runtime::replicated_sessions`] with an explicit buffer policy.
#[allow(clippy::too_many_arguments)]
pub fn replicated_sessions_with_buffer_policy(
env: &Environment, model_path: &str, opts: SessionOptions, mem: &MemoryInfo,
input_shapes: &[&[i64]], output_shapes: &[&[i64]], lanes: usize, policy: LaneBufferPolicy,
) -> Result<Self> {
Self::from_session_factory_with_buffer_policy(
lanes,
mem,
input_shapes,
output_shapes,
policy,
|_| Session::new(env, model_path, opts.clone()),
)
}
/// Creates `lanes` sessions with `Session::new_with_prepacked_weights`, passing
/// the same `prepacked` container to each, and builds one lane per session
/// using `LaneBufferPolicy::Auto`.
#[allow(clippy::too_many_arguments)]
pub fn replicated_sessions_with_prepacked_weights(
env: &Environment, model_path: &str, opts: SessionOptions,
prepacked: &PrepackedWeightsContainer, mem: &MemoryInfo, input_shapes: &[&[i64]],
output_shapes: &[&[i64]], lanes: usize,
) -> Result<Self> {
Self::from_session_factory(lanes, mem, input_shapes, output_shapes, |_| {
Session::new_with_prepacked_weights(env, model_path, opts.clone(), prepacked)
})
}
/// Same as [`Runtime::replicated_sessions_with_prepacked_weights`] with an
/// explicit buffer policy.
#[allow(clippy::too_many_arguments)]
pub fn replicated_sessions_with_prepacked_weights_and_buffer_policy(
env: &Environment, model_path: &str, opts: SessionOptions,
prepacked: &PrepackedWeightsContainer, mem: &MemoryInfo, input_shapes: &[&[i64]],
output_shapes: &[&[i64]], lanes: usize, policy: LaneBufferPolicy,
) -> Result<Self> {
Self::from_session_factory_with_buffer_policy(
lanes,
mem,
input_shapes,
output_shapes,
policy,
|_| Session::new_with_prepacked_weights(env, model_path, opts.clone(), prepacked),
)
}
}
impl<T: TensorElement> Runtime<T> {
#[inline]
pub fn len(&self) -> usize {
self.lanes.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.lanes.is_empty()
}
#[inline]
pub fn session_mode(&self) -> RuntimeMode {
self.mode
}
#[inline]
pub fn lanes(&self) -> &[Lane<T>] {
&self.lanes
}
#[inline]
pub fn lanes_mut(&mut self) -> &mut [Lane<T>] {
&mut self.lanes
}
#[inline]
pub fn lane(&self, i: usize) -> Result<&Lane<T>> {
self.lanes
.get(i)
.ok_or_else(|| Error::new(-1, format!("zrt: lane index {i} out of range")))
}
#[inline]
pub fn lane_mut(&mut self, i: usize) -> Result<&mut Lane<T>> {
self.lanes
.get_mut(i)
.ok_or_else(|| Error::new(-1, format!("zrt: lane index {i} out of range")))
}
#[inline]
pub fn into_lanes(self) -> Vec<Lane<T>> {
self.lanes
}
#[inline]
pub fn run_on<R>(&mut self, i: usize, f: impl FnOnce(&mut Lane<T>) -> Result<R>) -> Result<R> {
f(self.lane_mut(i)?)
}
pub fn prime(&mut self, runs: usize) -> Result<()> {
for lane in &mut self.lanes {
lane.prime(runs)?;
}
Ok(())
}
pub fn audit_hot_path(&self) -> Result<Vec<LaneHotPathAudit>> {
self.lanes.iter().map(Lane::audit_hot_path).collect()
}
pub fn assert_zero_copy_plan(&self) -> Result<()> {
for lane in &self.lanes {
lane.assert_zero_copy_plan()?;
}
Ok(())
}
#[inline]
pub fn into_lane_set(self) -> Self {
self
}
}
impl<I, O, const INPUTS: usize, const OUTPUTS: usize> StaticIoRuntime<I, O, INPUTS, OUTPUTS>
where
I: TensorElement + Clone + Default,
O: TensorElement + Clone + Default,
{
/// Builds `lanes` lanes sharing `session`, using `mem` for inputs and outputs
/// and `LaneBufferPolicy::Auto` for both policies.
pub fn shared_session(
session: Arc<Session>, mem: &MemoryInfo, input_shapes: [&[i64]; INPUTS],
output_shapes: [&[i64]; OUTPUTS], lanes: usize,
) -> Result<Self> {
Self::shared_session_with_buffer_policy(
session,
mem,
mem,
input_shapes,
output_shapes,
lanes,
LaneBufferPolicy::Auto,
LaneBufferPolicy::Auto,
)
}
/// Builds `lanes` lanes that each hold a clone of `session` and marks the
/// runtime as `RuntimeMode::SharedSession`.
///
/// # Errors
/// Returns an error with code `-1` when `lanes` is zero; otherwise propagates
/// the first error from `StaticIoLane::with_buffer_policy`.
#[allow(clippy::too_many_arguments)]
pub fn shared_session_with_buffer_policy(
session: Arc<Session>, input_mem: &MemoryInfo, output_mem: &MemoryInfo,
input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS], lanes: usize,
input_policy: LaneBufferPolicy, output_policy: LaneBufferPolicy,
) -> Result<Self> {
if lanes == 0 {
return Err(Error::new(-1, "StaticIoRuntime requires at least one lane"));
}
let lanes = (0..lanes)
.map(|_| {
StaticIoLane::with_buffer_policy(
session.clone(),
input_mem,
output_mem,
input_shapes,
output_shapes,
input_policy,
output_policy,
)
})
.collect::<Result<Vec<_>>>()?;
Ok(Self {
lanes,
mode: RuntimeMode::SharedSession,
})
}
/// Builds one lane per owned session, using `mem` for inputs and outputs and
/// `LaneBufferPolicy::Auto` for both policies.
pub fn from_sessions(
sessions: Vec<Session>, mem: &MemoryInfo, input_shapes: [&[i64]; INPUTS],
output_shapes: [&[i64]; OUTPUTS],
) -> Result<Self> {
Self::from_sessions_with_buffer_policy(
sessions,
mem,
mem,
input_shapes,
output_shapes,
LaneBufferPolicy::Auto,
LaneBufferPolicy::Auto,
)
}
/// Builds one lane per owned session, wrapping each session in a new `Arc`,
/// and marks the runtime as `RuntimeMode::ReplicatedSessions`.
///
/// # Errors
/// Returns an error with code `-1` when `sessions` is empty; otherwise
/// propagates the first lane construction error.
#[allow(clippy::too_many_arguments)]
pub fn from_sessions_with_buffer_policy(
sessions: Vec<Session>, input_mem: &MemoryInfo, output_mem: &MemoryInfo,
input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS],
input_policy: LaneBufferPolicy, output_policy: LaneBufferPolicy,
) -> Result<Self> {
if sessions.is_empty() {
return Err(Error::new(
-1,
"StaticIoRuntime requires at least one session",
));
}
let lanes = sessions
.into_iter()
.map(|session| {
StaticIoLane::with_buffer_policy(
Arc::new(session),
input_mem,
output_mem,
input_shapes,
output_shapes,
input_policy,
output_policy,
)
})
.collect::<Result<Vec<_>>>()?;
Ok(Self {
lanes,
mode: RuntimeMode::ReplicatedSessions,
})
}
/// Builds one lane per already-shared session, using `mem` for inputs and
/// outputs and `LaneBufferPolicy::Auto` for both policies.
pub fn from_session_arcs(
sessions: &[Arc<Session>], mem: &MemoryInfo, input_shapes: [&[i64]; INPUTS],
output_shapes: [&[i64]; OUTPUTS],
) -> Result<Self> {
Self::from_session_arcs_with_buffer_policy(
sessions,
mem,
mem,
input_shapes,
output_shapes,
LaneBufferPolicy::Auto,
LaneBufferPolicy::Auto,
)
}
/// Builds one lane per entry of `sessions`, cloning each `Arc`, and marks the
/// runtime as `RuntimeMode::ReplicatedSessions`.
///
/// # Errors
/// Returns an error with code `-1` when `sessions` is empty; otherwise
/// propagates the first lane construction error.
#[allow(clippy::too_many_arguments)]
pub fn from_session_arcs_with_buffer_policy(
sessions: &[Arc<Session>], input_mem: &MemoryInfo, output_mem: &MemoryInfo,
input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS],
input_policy: LaneBufferPolicy, output_policy: LaneBufferPolicy,
) -> Result<Self> {
if sessions.is_empty() {
return Err(Error::new(
-1,
"StaticIoRuntime requires at least one session",
));
}
let lanes = sessions
.iter()
.map(|session| {
StaticIoLane::with_buffer_policy(
session.clone(),
input_mem,
output_mem,
input_shapes,
output_shapes,
input_policy,
output_policy,
)
})
.collect::<Result<Vec<_>>>()?;
Ok(Self {
lanes,
mode: RuntimeMode::ReplicatedSessions,
})
}
/// Creates `lanes` sessions through `factory` and builds one lane per session,
/// using `mem` for inputs and outputs and `LaneBufferPolicy::Auto`.
pub fn from_session_factory<F>(
lanes: usize, mem: &MemoryInfo, input_shapes: [&[i64]; INPUTS],
output_shapes: [&[i64]; OUTPUTS], factory: F,
) -> Result<Self>
where
F: FnMut(usize) -> Result<Session>,
{
Self::from_session_factory_with_buffer_policy(
lanes,
mem,
mem,
input_shapes,
output_shapes,
LaneBufferPolicy::Auto,
LaneBufferPolicy::Auto,
factory,
)
}
/// Creates `lanes` sessions by calling `factory` with each index in
/// `0..lanes`, then builds one lane per session.
///
/// # Errors
/// Returns an error with code `-1` when `lanes` is zero; otherwise propagates
/// the first factory error or lane construction error.
#[allow(clippy::too_many_arguments)]
pub fn from_session_factory_with_buffer_policy<F>(
lanes: usize, input_mem: &MemoryInfo, output_mem: &MemoryInfo,
input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS],
input_policy: LaneBufferPolicy, output_policy: LaneBufferPolicy, mut factory: F,
) -> Result<Self>
where
F: FnMut(usize) -> Result<Session>,
{
if lanes == 0 {
return Err(Error::new(-1, "StaticIoRuntime requires at least one lane"));
}
// All sessions are created before any lane is built.
let sessions = (0..lanes).map(&mut factory).collect::<Result<Vec<_>>>()?;
Self::from_sessions_with_buffer_policy(
sessions,
input_mem,
output_mem,
input_shapes,
output_shapes,
input_policy,
output_policy,
)
}
}
impl<I: TensorElement, O: TensorElement, const INPUTS: usize, const OUTPUTS: usize>
StaticIoRuntime<I, O, INPUTS, OUTPUTS>
{
#[inline]
pub fn len(&self) -> usize {
self.lanes.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.lanes.is_empty()
}
#[inline]
pub fn session_mode(&self) -> RuntimeMode {
self.mode
}
#[inline]
pub fn lanes(&self) -> &[StaticIoLane<I, O, INPUTS, OUTPUTS>] {
&self.lanes
}
#[inline]
pub fn lanes_mut(&mut self) -> &mut [StaticIoLane<I, O, INPUTS, OUTPUTS>] {
&mut self.lanes
}
#[inline]
pub fn lane(&self, i: usize) -> Result<&StaticIoLane<I, O, INPUTS, OUTPUTS>> {
self.lanes
.get(i)
.ok_or_else(|| Error::new(-1, format!("zrt: static I/O lane index {i} out of range")))
}
#[inline]
pub fn lane_mut(&mut self, i: usize) -> Result<&mut StaticIoLane<I, O, INPUTS, OUTPUTS>> {
self.lanes
.get_mut(i)
.ok_or_else(|| Error::new(-1, format!("zrt: static I/O lane index {i} out of range")))
}
#[inline]
pub fn into_lanes(self) -> Vec<StaticIoLane<I, O, INPUTS, OUTPUTS>> {
self.lanes
}
#[inline]
pub fn run_on<R>(
&mut self, i: usize, f: impl FnOnce(&mut StaticIoLane<I, O, INPUTS, OUTPUTS>) -> Result<R>,
) -> Result<R> {
f(self.lane_mut(i)?)
}
pub fn prime(&mut self, runs: usize) -> Result<()> {
for lane in &mut self.lanes {
lane.prime(runs)?;
}
Ok(())
}
#[inline]
pub fn set_rebind_inputs_each_run(&mut self, enabled: bool) {
for lane in &mut self.lanes {
lane.set_rebind_inputs_each_run(enabled);
}
}
pub fn audit_hot_path(&self) -> Result<Vec<LaneHotPathAudit>> {
self.lanes
.iter()
.map(StaticIoLane::audit_hot_path)
.collect()
}
pub fn assert_zero_copy_plan(&self) -> Result<()> {
for lane in &self.lanes {
lane.assert_zero_copy_plan()?;
}
Ok(())
}
}
impl<I, O, const INPUTS: usize, const OUTPUTS: usize> DynamicIoRuntime<I, O, INPUTS, OUTPUTS>
where
I: TensorElement + Clone + Default,
O: TensorElement + Clone + Default,
{
/// Creates a shared-session runtime with default options.
///
/// `mem` is used for inputs; the output descriptor is obtained from
/// `mem.try_clone_descriptor()`.
///
/// # Errors
/// Propagates errors from `try_clone_descriptor` and from
/// `shared_session_with_options`.
pub fn shared_session(session: Arc<Session>, mem: MemoryInfo, lanes: usize) -> Result<Self> {
let output_mem = mem.try_clone_descriptor()?;
Self::shared_session_with_options(
session,
mem,
output_mem,
lanes,
DynamicIoOptions::default(),
)
}
pub fn shared_session_with_options(
session: Arc<Session>, input_mem: MemoryInfo, output_mem: MemoryInfo, lanes: usize,
options: DynamicIoOptions,
) -> Result<Self> {
if lanes == 0 {
return Err(Error::new(
-1,
"DynamicIoRuntime requires at least one lane",
));
}
Ok(Self {
sessions: DynamicSessions::Shared(session),
input_mem,
output_mem,
options: options.validate()?,
lane_count: lanes,
buckets: Vec::new(),
hot_bucket: None,
tick: 0,
})
}
/// Creates a replicated-session runtime from owned sessions with default options.
///
/// `mem` is used for inputs; the output descriptor is obtained from
/// `mem.try_clone_descriptor()`.
pub fn from_sessions(sessions: Vec<Session>, mem: MemoryInfo) -> Result<Self> {
let output_mem = mem.try_clone_descriptor()?;
Self::from_sessions_with_options(sessions, mem, output_mem, DynamicIoOptions::default())
}
/// Wraps every owned session in an `Arc` and forwards to
/// `from_session_arcs_with_options`.
pub fn from_sessions_with_options(
    sessions: Vec<Session>, input_mem: MemoryInfo, output_mem: MemoryInfo,
    options: DynamicIoOptions,
) -> Result<Self> {
    let shared: Vec<Arc<Session>> = sessions.into_iter().map(Arc::new).collect();
    Self::from_session_arcs_with_options(shared, input_mem, output_mem, options)
}
/// Creates a replicated-session runtime from already-shared sessions with
/// default options.
///
/// `mem` is used for inputs; the output descriptor is obtained from
/// `mem.try_clone_descriptor()`.
pub fn from_session_arcs(sessions: Vec<Arc<Session>>, mem: MemoryInfo) -> Result<Self> {
let output_mem = mem.try_clone_descriptor()?;
Self::from_session_arcs_with_options(sessions, mem, output_mem, DynamicIoOptions::default())
}
pub fn from_session_arcs_with_options(
sessions: Vec<Arc<Session>>, input_mem: MemoryInfo, output_mem: MemoryInfo,
options: DynamicIoOptions,
) -> Result<Self> {
if sessions.is_empty() {
return Err(Error::new(
-1,
"DynamicIoRuntime requires at least one session",
));
}
let lane_count = sessions.len();
Ok(Self {
sessions: DynamicSessions::Replicated(sessions),
input_mem,
output_mem,
options: options.validate()?,
lane_count,
buckets: Vec::new(),
hot_bucket: None,
tick: 0,
})
}
/// Creates a replicated-session runtime whose sessions come from `factory`,
/// with default options.
///
/// `mem` is used for inputs; the output descriptor is obtained from
/// `mem.try_clone_descriptor()`, which runs before the factory is invoked.
pub fn from_session_factory<F>(lanes: usize, mem: MemoryInfo, factory: F) -> Result<Self>
where
F: FnMut(usize) -> Result<Session>,
{
let output_mem = mem.try_clone_descriptor()?;
Self::from_session_factory_with_options(
lanes,
mem,
output_mem,
DynamicIoOptions::default(),
factory,
)
}
/// Creates a replicated-session runtime whose sessions come from `factory`.
///
/// `factory` is called once for each index in `0..lanes`; the resulting
/// sessions are handed to `from_sessions_with_options`.
///
/// # Errors
/// Returns an error with code `-1` when `lanes` is zero, the validation error
/// when `options` is invalid, or the first error returned by `factory`.
/// Both argument checks run before `factory` is invoked, so invalid arguments
/// never cause sessions to be created.
pub fn from_session_factory_with_options<F>(
    lanes: usize, input_mem: MemoryInfo, output_mem: MemoryInfo, options: DynamicIoOptions,
    mut factory: F,
) -> Result<Self>
where
    F: FnMut(usize) -> Result<Session>,
{
    if lanes == 0 {
        return Err(Error::new(
            -1,
            "DynamicIoRuntime requires at least one lane",
        ));
    }
    // Fail fast: without this check the options would only be rejected
    // downstream, after every session had already been created and then
    // immediately thrown away.
    let options = options.validate()?;
    let sessions = (0..lanes).map(&mut factory).collect::<Result<Vec<_>>>()?;
    Self::from_sessions_with_options(sessions, input_mem, output_mem, options)
}
/// Advances the internal usage counter and returns its new value.
///
/// The returned tick is never zero: when the counter wraps around, it
/// restarts at 1 instead of 0.
fn next_tick(&mut self) -> u64 {
    let advanced = self.tick.wrapping_add(1);
    self.tick = if advanced == 0 { 1 } else { advanced };
    self.tick
}
/// Builds a fresh [`StaticIoRuntime`] for the given input/output shapes.
///
/// The lane set is created from the configured sessions (shared or
/// replicated), the stored input/output memory infos and the buffer policies
/// from the options. The `rebind_inputs_each_run` option is applied to the
/// result before it is returned.
///
/// # Errors
/// Propagates the error from the underlying `StaticIoRuntime` constructor.
fn build_lane_set(
    &self, input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS],
) -> Result<StaticIoRuntime<I, O, INPUTS, OUTPUTS>> {
    let input_policy = self.options.input_policy;
    let output_policy = self.options.output_policy;
    let built = match &self.sessions {
        DynamicSessions::Replicated(sessions) => {
            StaticIoRuntime::from_session_arcs_with_buffer_policy(
                sessions,
                &self.input_mem,
                &self.output_mem,
                input_shapes,
                output_shapes,
                input_policy,
                output_policy,
            )
        },
        DynamicSessions::Shared(session) => StaticIoRuntime::shared_session_with_buffer_policy(
            session.clone(),
            &self.input_mem,
            &self.output_mem,
            input_shapes,
            output_shapes,
            self.lane_count,
            input_policy,
            output_policy,
        ),
    };
    let mut lanes = built?;
    lanes.set_rebind_inputs_each_run(self.options.rebind_inputs_each_run);
    Ok(lanes)
}
/// Returns the index of the cached bucket whose key matches the given shapes,
/// or `None` when no such bucket exists.
///
/// When more than one bucket is cached, the remembered hot bucket is checked
/// first; otherwise, or on a miss, the buckets are scanned in order.
fn find_bucket_index(
    &self, input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS],
) -> Option<usize> {
    // With zero or one bucket the shortcut saves nothing over the scan.
    if self.buckets.len() > 1 {
        let hot_hit = self.hot_bucket.filter(|&slot| {
            self.buckets
                .get(slot)
                .is_some_and(|bucket| bucket.key.matches(input_shapes, output_shapes))
        });
        if hot_hit.is_some() {
            return hot_hit;
        }
    }
    self.buckets
        .iter()
        .position(|bucket| bucket.key.matches(input_shapes, output_shapes))
}
/// Makes room for one more bucket when the cache has reached `max_buckets`.
///
/// Does nothing while the cache is below capacity. Otherwise the hot-bucket
/// marker is cleared and the bucket with the smallest `last_used` tick is
/// removed.
fn evict_one_bucket_if_full(&mut self) {
    let at_capacity = self.buckets.len() >= self.options.max_buckets;
    if !at_capacity {
        return;
    }
    // `swap_remove` moves the last bucket into the vacated slot, so a stored
    // hot index could end up pointing at a different bucket.
    self.hot_bucket = None;
    let victim = self
        .buckets
        .iter()
        .enumerate()
        .min_by_key(|entry| entry.1.last_used)
        .map(|(slot, _)| slot);
    if let Some(slot) = victim {
        self.buckets.swap_remove(slot);
    }
}
/// Returns the bucket for the given shapes, building and caching it when it
/// is not cached yet.
///
/// * Hit: the bucket's `last_used` is refreshed with a new tick and the
///   bucket becomes the hot bucket.
/// * Miss: one bucket is evicted if the cache is full, a new lane set is
///   built, and the new bucket is appended and becomes the hot bucket.
///
/// # Errors
/// Propagates the error from building the lane set. Eviction runs before the
/// build, so when the cache was full a failed build leaves it one bucket
/// smaller.
pub fn get_or_create_bucket(
    &mut self, input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS],
) -> Result<&mut ShapeBucket<I, O, INPUTS, OUTPUTS>> {
    match self.find_bucket_index(input_shapes, output_shapes) {
        Some(slot) => {
            let stamp = self.next_tick();
            let bucket = &mut self.buckets[slot];
            bucket.last_used = stamp;
            self.hot_bucket = Some(slot);
            Ok(bucket)
        },
        None => {
            self.evict_one_bucket_if_full();
            let key = ShapeKey::new(input_shapes, output_shapes);
            let lanes = self.build_lane_set(input_shapes, output_shapes)?;
            let last_used = self.next_tick();
            self.buckets.push(ShapeBucket {
                key,
                lanes,
                last_used,
            });
            let newest = self.buckets.len() - 1;
            self.hot_bucket = Some(newest);
            self.buckets.last_mut().ok_or_else(|| {
                Error::new(
                    -1,
                    "zrt: failed to access newly created dynamic shape bucket",
                )
            })
        },
    }
}
}
impl<I, O, const INPUTS: usize, const OUTPUTS: usize> DynamicIoRuntime<I, O, INPUTS, OUTPUTS>
where
    I: TensorElement + Clone + Default,
    O: TensorElement + Clone + Default,
{
    /// Number of shape buckets currently cached.
    #[inline]
    pub fn bucket_count(&self) -> usize {
        self.buckets.len()
    }

    /// The configured `max_buckets` option.
    #[inline]
    pub fn max_buckets(&self) -> usize {
        self.options.max_buckets
    }

    /// The lane count stored on this runtime.
    #[inline]
    pub fn lane_count(&self) -> usize {
        self.lane_count
    }

    /// Reports whether the lanes share one session or use replicated sessions.
    #[inline]
    pub fn session_mode(&self) -> RuntimeMode {
        match &self.sessions {
            DynamicSessions::Replicated(_) => RuntimeMode::ReplicatedSessions,
            DynamicSessions::Shared(_) => RuntimeMode::SharedSession,
        }
    }

    /// Returns a copy of the options this runtime was built with.
    #[inline]
    pub fn options(&self) -> DynamicIoOptions {
        self.options
    }

    /// Read-only view of the cached buckets, in storage order.
    #[inline]
    pub fn buckets(&self) -> &[ShapeBucket<I, O, INPUTS, OUTPUTS>] {
        self.buckets.as_slice()
    }

    /// Mutable view of the cached buckets, in storage order.
    #[inline]
    pub fn buckets_mut(&mut self) -> &mut [ShapeBucket<I, O, INPUTS, OUTPUTS>] {
        self.buckets.as_mut_slice()
    }

    /// Drops every cached bucket and forgets the hot bucket.
    #[inline]
    pub fn clear_buckets(&mut self) {
        self.hot_bucket = None;
        self.buckets.clear();
    }

    /// Looks up the cached bucket for the given shapes.
    ///
    /// Unlike [`Self::bucket_mut`], this changes neither the bucket's
    /// `last_used` tick nor the hot-bucket marker.
    #[inline]
    pub fn bucket(
        &self, input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS],
    ) -> Option<&ShapeBucket<I, O, INPUTS, OUTPUTS>> {
        let slot = self.find_bucket_index(input_shapes, output_shapes)?;
        Some(&self.buckets[slot])
    }

    /// Looks up the cached bucket for the given shapes for mutation.
    ///
    /// On a hit the bucket's `last_used` is refreshed with a new tick and the
    /// bucket becomes the hot bucket. Returns `None` when nothing matches; no
    /// bucket is created.
    pub fn bucket_mut(
        &mut self, input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS],
    ) -> Option<&mut ShapeBucket<I, O, INPUTS, OUTPUTS>> {
        let slot = self.find_bucket_index(input_shapes, output_shapes)?;
        let stamp = self.next_tick();
        let bucket = &mut self.buckets[slot];
        bucket.last_used = stamp;
        self.hot_bucket = Some(slot);
        Some(bucket)
    }

    /// Removes the cached bucket for the given shapes.
    ///
    /// Returns `true` when a bucket was removed and `false` when none
    /// matched. A successful removal also clears the hot-bucket marker.
    pub fn remove_bucket(
        &mut self, input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS],
    ) -> bool {
        match self.find_bucket_index(input_shapes, output_shapes) {
            None => false,
            Some(slot) => {
                // `swap_remove` reshuffles indices, so the hot index is dropped.
                self.hot_bucket = None;
                self.buckets.swap_remove(slot);
                true
            },
        }
    }

    /// Ensures a bucket exists for every spec, in iteration order.
    ///
    /// Returns the number of specs processed. Repeated specs are counted each
    /// time, so this can exceed the number of buckets left in the cache.
    ///
    /// # Errors
    /// Stops at the first spec whose bucket cannot be obtained and returns
    /// that error.
    pub fn prebuild_buckets<'a>(
        &mut self, specs: impl IntoIterator<Item = ShapeSpec<'a, INPUTS, OUTPUTS>>,
    ) -> Result<usize> {
        specs
            .into_iter()
            .try_fold(0usize, |built, spec| -> Result<usize> {
                self.get_or_create_bucket(spec.input_shapes, spec.output_shapes)?;
                Ok(built + 1)
            })
    }

    /// Gets or creates the bucket for the given shapes and calls
    /// `prime(runs)` on it.
    ///
    /// # Errors
    /// Propagates errors from obtaining the bucket and from `prime`.
    pub fn prime_bucket(
        &mut self, input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS], runs: usize,
    ) -> Result<()> {
        let bucket = self.get_or_create_bucket(input_shapes, output_shapes)?;
        bucket.prime(runs)
    }

    /// Calls `prime(runs)` on every cached bucket, in storage order.
    ///
    /// Neither the `last_used` ticks nor the hot-bucket marker are touched.
    ///
    /// # Errors
    /// Stops at the first bucket whose `prime` fails and returns that error.
    pub fn prime_cached_buckets(&mut self, runs: usize) -> Result<()> {
        self.buckets
            .iter_mut()
            .try_for_each(|bucket| bucket.prime(runs))
    }

    /// For every spec, gets or creates its bucket and calls `prime(runs)` on it.
    ///
    /// Returns the number of specs processed; repeated specs are counted each
    /// time.
    ///
    /// # Errors
    /// Stops at the first spec that fails to build or prime and returns that
    /// error.
    pub fn warm_buckets<'a>(
        &mut self, specs: impl IntoIterator<Item = ShapeSpec<'a, INPUTS, OUTPUTS>>, runs: usize,
    ) -> Result<usize> {
        specs
            .into_iter()
            .try_fold(0usize, |warmed, spec| -> Result<usize> {
                let bucket = self.get_or_create_bucket(spec.input_shapes, spec.output_shapes)?;
                bucket.prime(runs)?;
                Ok(warmed + 1)
            })
    }

    /// Gets or creates the bucket for the given shapes and forwards `lane`
    /// and `f` to that bucket's `run_on`.
    ///
    /// # Errors
    /// Propagates errors from obtaining the bucket and from the bucket's
    /// `run_on`.
    #[inline]
    pub fn run_on<R>(
        &mut self, input_shapes: [&[i64]; INPUTS], output_shapes: [&[i64]; OUTPUTS], lane: usize,
        f: impl FnOnce(&mut StaticIoLane<I, O, INPUTS, OUTPUTS>) -> Result<R>,
    ) -> Result<R> {
        let bucket = self.get_or_create_bucket(input_shapes, output_shapes)?;
        bucket.run_on(lane, f)
    }

    /// Collects the hot-path audit of every cached bucket, one entry per
    /// bucket in storage order.
    ///
    /// # Errors
    /// Stops at the first bucket whose audit fails and returns that error.
    pub fn audit_cached_hot_paths(&self) -> Result<Vec<Vec<LaneHotPathAudit>>> {
        let mut audits = Vec::with_capacity(self.buckets.len());
        for bucket in &self.buckets {
            audits.push(bucket.audit_hot_path()?);
        }
        Ok(audits)
    }

    /// Runs `assert_zero_copy_plan` on every cached bucket, in storage order.
    ///
    /// # Errors
    /// Stops at the first bucket whose assertion fails and returns that error.
    pub fn assert_cached_zero_copy_plan(&self) -> Result<()> {
        self.buckets
            .iter()
            .try_for_each(|bucket| -> Result<()> {
                bucket.assert_zero_copy_plan()?;
                Ok(())
            })
    }
}