#![allow(missing_docs)]
pub use crate::ffi::ml_v1::SnapshotFormat;
use crate::{ffi::ml_inference_v0 as ffi_v0, Error};
use ffi_v0::AgentId;
use ffi_v0::LoadInferenceHandle;
use ffi_v0::RawInferenceHandle;
use ffi_v0::RuntimeHandle;
#[doc(hidden)]
pub use ffi_v0::API as FFI_API;
use std::collections::HashMap;
use crossbeam_channel::{unbounded, Receiver, Sender};
use std::task::Poll;
/// Decoded per-agent inference output: output name -> vector of f32 values.
pub type AgentData = HashMap<String, Vec<f32>>;
/// Opaque identifier for a loaded brain; newtype over the raw FFI handle so it
/// cannot be mixed up with other integer handles.
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub struct BrainId(ffi_v0::RawInferenceHandle);
/// Future that resolves to an [`InferenceInstance`] once the asynchronous load
/// started by [`InferenceRuntime::load`] completes (driven by
/// [`InferenceRuntime::update`]).
pub struct InferenceFuture {
    // Receiving half of the completion channel; the sending half is stored in
    // the runtime's pending-brain list.
    recv: Receiver<InferenceInstance>,
}
impl std::future::Future for InferenceFuture {
    type Output = Result<InferenceInstance, Error>;

    /// Polls the completion channel.
    ///
    /// Resolves with the loaded [`InferenceInstance`] once
    /// [`InferenceRuntime::update`] delivers it, or with
    /// [`Error::Unavailable`] when the sender was dropped (the pending load
    /// was abandoned or failed).
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        match self.recv.try_recv() {
            Ok(instance) => Poll::Ready(Ok(instance)),
            Err(crossbeam_channel::TryRecvError::Empty) => {
                // The crossbeam channel has no waker integration, so request
                // an immediate re-poll. Without this the waker is never
                // stored or woken, and on a real executor the task would
                // never be polled again — the future would hang forever.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            Err(crossbeam_channel::TryRecvError::Disconnected) => {
                Poll::Ready(Err(Error::Unavailable))
            }
        }
    }
}
/// Bookkeeping for one in-flight asynchronous brain load.
struct PendingBrain {
    // FFI handle used to poll the asynchronous load for completion.
    handle: LoadInferenceHandle,
    // NOTE(review): despite the name, this is the *sending* half of the
    // completion channel; the paired `InferenceFuture` owns the receiver.
    receiver: Sender<InferenceInstance>,
}
/// Thin wrapper over the raw handle of a fully loaded brain.
struct Brain {
    handle: RawInferenceHandle,
}
/// Owning wrapper around an FFI inference runtime. Destroys the runtime on
/// drop; tracks asynchronous brain loads until [`Self::update`] completes them.
pub struct InferenceRuntime {
    handle: RuntimeHandle,
    // Loads started via `load()` that have not yet completed or failed.
    pending_brains: Vec<PendingBrain>,
}
/// Lightweight, copyable identifier for a loaded inference instance: the
/// runtime handle plus the raw inference handle. Unlike [`InferenceInstance`]
/// it does not own the underlying resource.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Copy)]
#[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
pub struct InferenceHandle {
    runtime: RuntimeHandle,
    inference: RawInferenceHandle,
}
/// Owning handle to a loaded brain. Unloads the inference via FFI on drop.
pub struct InferenceInstance {
    // Externally managed reference count; a non-zero value at drop time is
    // logged as an error (see the `Drop` impl).
    ref_count: usize,
    runtime: RuntimeHandle,
    inference: RawInferenceHandle,
}
impl Drop for InferenceInstance {
    /// Unloads the inference from the runtime. `drop` cannot return errors,
    /// so both an outstanding reference count and an FFI failure are logged.
    fn drop(&mut self) {
        // A non-zero count means callers still believe they hold references;
        // we unload anyway, but log loudly so the imbalance is noticed.
        if self.ref_count != 0 {
            log::error!("inference instance being dropped with existing references: {:?}", self.ref_count);
        }
        if let Err(e) = ffi_v0::unload_inference(self.runtime, self.inference) {
            log::error!("failed unloading inference: {e:?}");
        }
    }
}
impl InferenceInstance {
pub fn handle(&self) -> InferenceHandle {
InferenceHandle {
runtime: self.runtime,
inference: self.inference,
}
}
pub fn increment_refcount(&mut self) -> usize {
self.ref_count += 1;
self.ref_count
}
pub fn decrease_refcount(&mut self) -> usize {
self.ref_count -= 1;
self.ref_count
}
pub fn submit_infer(
&mut self,
agent_id: AgentId,
states: &HashMap<String, Vec<f32>>,
) -> Result<(), Error> {
ffi_v0::begin_infer_submission(self.runtime, self.inference, agent_id)?;
for (k, v) in states {
ffi_v0::submit_inference_item(self.runtime, self.inference, k, v)
.map_err(Error::from)?;
}
ffi_v0::end_infer_submission(self.runtime, self.inference).map_err(Error::from)
}
pub fn poll_results(&self) -> Result<Vec<(AgentId, AgentData)>, Error> {
let mut output = vec![];
let brain_agents = decode_u64(ffi_v0::result_get_brain_agents(
self.runtime,
self.inference,
)?);
let brain_outputs: Vec<String> =
ffi_v0::result_get_brain_output_names(self.runtime, self.inference)
.map_err(Error::from)
.and_then(decode_strings)?;
for agent_id in brain_agents {
let mut agent_data = HashMap::default();
for (idx, name) in brain_outputs.iter().enumerate() {
agent_data.insert(
name.clone(),
decode_f32(ffi_v0::result_get_brain_agent_output(
self.runtime,
self.inference,
agent_id,
idx as u32,
)?),
);
}
output.push((agent_id, agent_data));
}
Ok(output)
}
}
impl Default for InferenceRuntime {
    /// Equivalent to [`InferenceRuntime::create`].
    fn default() -> Self {
        Self::create()
    }
}
impl InferenceRuntime {
    /// Creates a runtime backed by a fresh FFI runtime handle.
    pub fn create() -> Self {
        Self {
            handle: ffi_v0::create_runtime(),
            pending_brains: vec![],
        }
    }

    /// Runs queued inference work, optionally bounded to `limit_s` seconds,
    /// and returns the ids of agents whose results are now available.
    ///
    /// # Errors
    /// Propagates the FFI error from the execute call.
    pub fn execute_inference(&self, limit_s: Option<f32>) -> Result<Vec<u64>, Error> {
        let agents_bytes = match limit_s {
            Some(limit) => ffi_v0::execute_inference_for(self.handle, limit)?,
            None => ffi_v0::execute_inference(self.handle)?,
        };
        Ok(decode_u64(agents_bytes))
    }

    /// Starts loading a brain snapshot asynchronously. The returned future
    /// resolves once a subsequent [`Self::update`] observes the load
    /// completing (or failing).
    ///
    /// # Errors
    /// Propagates the FFI error from starting the load.
    pub fn load(&mut self, format: SnapshotFormat, data: &[u8]) -> Result<InferenceFuture, Error> {
        let (tx, rx) = unbounded();
        let handle = ffi_v0::load_inference_async(self.handle, format as u32, data)?;
        self.pending_brains.push(PendingBrain {
            handle,
            receiver: tx,
        });
        Ok(InferenceFuture { recv: rx })
    }

    /// Polls every in-flight load once. Completed loads are handed to their
    /// waiting [`InferenceFuture`]; hard failures are dropped from the
    /// pending list.
    ///
    /// # Errors
    /// Returns [`Error::Internal`] if any load failed this tick.
    pub fn update(&mut self) -> Result<(), Error> {
        // Copy the handle out so the closure below captures a plain value
        // instead of borrowing `self` while `pending_brains` is borrowed.
        let runtime = self.handle;
        let mut failed = false;
        self.pending_brains.retain_mut(|pending| {
            match ffi_v0::poll_load_inference_async(pending.handle) {
                Ok(inference) => {
                    // Load finished: deliver the owning instance. A send
                    // failure only means the future was dropped, in which
                    // case the instance is dropped (and unloaded) here.
                    let _ = pending.receiver.send(InferenceInstance {
                        runtime,
                        inference,
                        ref_count: 0,
                    });
                    false
                }
                // Still loading — keep the entry and poll again next update.
                Err(crate::ErrorCode::Unavailable) => true,
                Err(_error_code) => {
                    // Hard failure: remove the entry instead of retrying it
                    // forever (the original kept it, so every later update()
                    // re-reported the same error and the future never
                    // resolved). Dropping the sender disconnects the channel,
                    // making the paired future yield `Error::Unavailable`.
                    failed = true;
                    false
                }
            }
        });
        if failed {
            Err(Error::Internal)
        } else {
            Ok(())
        }
    }
}
impl Drop for InferenceRuntime {
    /// Destroys the FFI runtime. `drop` cannot return errors, so a failure
    /// is only logged.
    fn drop(&mut self) {
        if let Err(err) = ffi_v0::destroy_runtime(self.handle) {
            log::error!("failed destroying runtime {err:?}");
        }
    }
}
fn decode_strings(data: Vec<u8>) -> Result<Vec<String>, Error> {
let mut stream = data.into_iter();
let count = stream.next().unwrap();
let mut result = vec![];
for _ in 0..count {
let length = stream.next().unwrap();
let mut bytes = vec![];
for _ in 0..length {
bytes.push(stream.next().unwrap());
}
result.push(String::from_utf8(bytes).map_err(|_ignored| Error::Internal)?);
}
Ok(result)
}
/// Decodes a length-prefixed list of little-endian u64 values.
///
/// Wire format: `[count: u8]` followed by `count` 8-byte little-endian words.
/// Panics if the payload is shorter than the prefix claims.
fn decode_u64(data: Vec<u8>) -> Vec<u64> {
    let count = usize::from(data[0]);
    let payload = &data[1..];
    (0..count)
        .map(|i| {
            let mut word = [0u8; 8];
            // Panics (index out of range) on a truncated payload, matching
            // the wire contract of the other decoders in this file.
            word.copy_from_slice(&payload[i * 8..(i + 1) * 8]);
            u64::from_le_bytes(word)
        })
        .collect()
}
/// Decodes a length-prefixed list of little-endian f32 values.
///
/// Wire format: `[count: u8]` followed by `count` 4-byte little-endian floats.
/// Panics if the payload is shorter than the prefix claims.
fn decode_f32(data: Vec<u8>) -> Vec<f32> {
    let count = usize::from(data[0]);
    let payload = &data[1..];
    (0..count)
        .map(|i| {
            let mut word = [0u8; 4];
            // Panics (index out of range) on a truncated payload, matching
            // the wire contract of the other decoders in this file.
            word.copy_from_slice(&payload[i * 4..(i + 1) * 4]);
            f32::from_le_bytes(word)
        })
        .collect()
}