// ark-api 0.17.0-pre.15
//
// Ark API
// Documentation
//! # 🧠 Machine Learning Inference API
//!
//! Machine Learning API for Runtime Inference.
//!
//! Supports loading and inferring brains from game-code with automatic batching and runtime scaling. This leads to
//! overall better performance than using the legacy inference support in the regular ML API, while saving memory and
//! reducing bookkeeping for other components.

#![allow(missing_docs)]

pub use crate::ffi::ml_v1::SnapshotFormat;
use crate::{ffi::ml_inference_v0 as ffi_v0, Error};
use ffi_v0::AgentId;
use ffi_v0::LoadInferenceHandle;
use ffi_v0::RawInferenceHandle;
use ffi_v0::RuntimeHandle;
#[doc(hidden)]
pub use ffi_v0::API as FFI_API;
use std::collections::HashMap;
use crossbeam_channel::{unbounded, Receiver, Sender};
use std::task::Poll;

/// Per-agent data keyed by name: each entry maps a name to a vector of `f32` values.
pub type AgentData = std::collections::HashMap<String, Vec<f32>>;

/// Identifier wrapping a raw inference handle.
///
/// NOTE(review): nothing else in this module constructs or consumes this type; confirm whether
/// external callers still rely on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct BrainId(ffi_v0::RawInferenceHandle);

/// A future resolving to an inference handle when loading is completed.
pub struct InferenceFuture {
    recv: Receiver<InferenceInstance>,
}

impl std::future::Future for InferenceFuture {
    type Output = Result<InferenceInstance, Error>;

    fn poll(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        match self.recv.try_recv() {
            Ok(handle) => Poll::Ready(Ok(handle)),
            Err(crossbeam_channel::TryRecvError::Empty) => Poll::Pending,
            Err(crossbeam_channel::TryRecvError::Disconnected) => {
                Poll::Ready(Err(Error::Unavailable))
            }
        }
    }
}

/// A load started by [`InferenceRuntime::load`] that has not finished yet.
struct PendingBrain {
    /// Host handle used to poll the asynchronous load.
    handle: LoadInferenceHandle,
    /// Sending half of the channel whose receiver lives in the matching [`InferenceFuture`].
    /// (Despite the name, this is the `Sender`.)
    receiver: Sender<InferenceInstance>,
}

/// Thin wrapper around a loaded inference handle, used while completing a pending load.
struct Brain {
    handle: RawInferenceHandle,
}

/// Owns a host-side inference runtime and tracks the model loads it has started.
///
/// The runtime manages the lifetime of multiple inference models, from load to unload. Loads
/// complete only when [`InferenceRuntime::update`] is called.
pub struct InferenceRuntime {
    /// Host handle of the runtime; destroyed when this value is dropped.
    handle: RuntimeHandle,
    /// Loads that have been started but not yet observed as finished.
    pending_brains: Vec<PendingBrain>,
}

/// A copyable reference to a specific inference model inside a specific runtime.
///
/// It is safe for this handle to outlive the [`InferenceRuntime`] that provided it, but it stops
/// working if that happens, and the model then has to be loaded again. The handle does not keep
/// the model loaded: dropping the owning [`InferenceInstance`] unloads it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
pub struct InferenceHandle {
    // Field order matters for the derived `Hash` and for speedy serialization.
    runtime: RuntimeHandle,
    inference: RawInferenceHandle,
}

/// An owned, loaded inference model. Dropping it unloads the model from its runtime.
pub struct InferenceInstance {
    /// Runtime the model was loaded into.
    runtime: RuntimeHandle,
    /// Host handle of the loaded model.
    inference: RawInferenceHandle,
    /// Reference count maintained by callers through the increment/decrease methods.
    ref_count: usize,
}

impl Drop for InferenceInstance {
    fn drop(&mut self) {
        if self.ref_count != 0 {
            log::error!("inference instance being dropped with existing references: {:?}", self.ref_count);
        }

        if let Err(e) = ffi_v0::unload_inference(self.runtime, self.inference) {
            log::error!("failed unloading inference: {e:?}");
        }
    }
}

impl InferenceInstance {
    pub fn handle(&self) -> InferenceHandle {
        InferenceHandle {
            runtime: self.runtime,
            inference: self.inference,
        }
    }

    pub fn increment_refcount(&mut self) -> usize {
        self.ref_count += 1;
        self.ref_count
    }

    pub fn decrease_refcount(&mut self) -> usize {
        self.ref_count -= 1;
        self.ref_count
    }

    /// Push data to the next inference batch.
    pub fn submit_infer(
        &mut self,
        agent_id: AgentId,
        states: &HashMap<String, Vec<f32>>,
    ) -> Result<(), Error> {
        ffi_v0::begin_infer_submission(self.runtime, self.inference, agent_id)?;
        for (k, v) in states {
            ffi_v0::submit_inference_item(self.runtime, self.inference, k, v)
                .map_err(Error::from)?;
        }
        ffi_v0::end_infer_submission(self.runtime, self.inference).map_err(Error::from)
    }

    /// Check if there is any result to retrieve for this call at this time.
    pub fn poll_results(&self) -> Result<Vec<(AgentId, AgentData)>, Error> {
        let mut output = vec![];
        let brain_agents = decode_u64(ffi_v0::result_get_brain_agents(
            self.runtime,
            self.inference,
        )?);

        let brain_outputs: Vec<String> =
            ffi_v0::result_get_brain_output_names(self.runtime, self.inference)
                .map_err(Error::from)
                .and_then(decode_strings)?;

        for agent_id in brain_agents {
            let mut agent_data = HashMap::default();
            for (idx, name) in brain_outputs.iter().enumerate() {
                agent_data.insert(
                    name.clone(),
                    decode_f32(ffi_v0::result_get_brain_agent_output(
                        self.runtime,
                        self.inference,
                        agent_id,
                        idx as u32,
                    )?),
                );
            }

            output.push((agent_id, agent_data));
        }

        Ok(output)
    }
}

impl Default for InferenceRuntime {
    fn default() -> Self {
        Self::create()
    }
}

impl InferenceRuntime {
    pub fn create() -> Self {
        let handle = ffi_v0::create_runtime();
        Self {
            handle,
            pending_brains: vec![],
        }
    }

    pub fn execute_inference(&self, limit_s: Option<f32>) -> Result<Vec<u64>, Error> {
        let agents_bytes = if let Some(limit) = limit_s {
            ffi_v0::execute_inference_for(self.handle, limit).map_err(Error::from)
        } else {
            ffi_v0::execute_inference(self.handle).map_err(Error::from)
        }?;

        let agent_ids = decode_u64(agents_bytes);
        Ok(agent_ids)
    }

    pub fn load(&mut self, format: SnapshotFormat, data: &[u8]) -> Result<InferenceFuture, Error> {
        let (tx, rx) = unbounded();

        let handle = ffi_v0::load_inference_async(self.handle, format as u32, data)?;
        self.pending_brains.push(PendingBrain {
            handle,
            receiver: tx,
        });
        Ok(InferenceFuture { recv: rx })
    }

    pub fn update(&mut self) -> Result<(), Error> {
        let Self { pending_brains, .. } = self;
        let mut errors = vec![];

        pending_brains.retain_mut(|pending| {
            match ffi_v0::poll_load_inference_async(pending.handle) {
                Ok(handle) => {
                    let brain = Brain { handle };

                    let _ = pending.receiver.send(InferenceInstance {
                        runtime: self.handle,
                        inference: brain.handle,
                        ref_count: 0,
                    });
                    false
                }
                Err(crate::ErrorCode::Unavailable) => true,
                Err(error_code) => {
                    errors.push(Error::from(error_code));
                    true
                }
            }
        });

        if errors.is_empty() {
            Ok(())
        } else {
            Err(Error::Internal)
        }
    }
}

impl Drop for InferenceRuntime {
    /// Destroys the host-side runtime, logging any failure because `drop` cannot return one.
    ///
    /// After this runs, `pending_brains` is dropped too. That drops each pending load's sender,
    /// so their futures resolve to `Err(Error::Unavailable)`.
    fn drop(&mut self) {
        // NOTE(review): no check is made for InferenceInstance values or handles that still
        // refer to this runtime; consider warning about them.
        let outcome = ffi_v0::destroy_runtime(self.handle);
        if let Err(err) = outcome {
            log::error!("failed destroying runtime {err:?}");
        }
    }
}

fn decode_strings(data: Vec<u8>) -> Result<Vec<String>, Error> {
    let mut stream = data.into_iter();
    let count = stream.next().unwrap();

    let mut result = vec![];
    for _ in 0..count {
        let length = stream.next().unwrap();
        let mut bytes = vec![];
        for _ in 0..length {
            bytes.push(stream.next().unwrap());
        }

        result.push(String::from_utf8(bytes).map_err(|_ignored| Error::Internal)?);
    }

    Ok(result)
}

/// Decodes a host-encoded list of `u64` values.
///
/// Layout: one count byte, then `count` little-endian 8-byte values. Bytes after the last
/// value are ignored.
///
/// # Panics
/// Panics if `data` is empty or shorter than the count byte announces.
fn decode_u64(data: Vec<u8>) -> Vec<u64> {
    const WIDTH: usize = 8;
    let (&count, payload) = data.split_first().unwrap();

    let mut values = Vec::new();
    for index in 0..usize::from(count) {
        let offset = index * WIDTH;
        let mut word = [0u8; WIDTH];
        word.copy_from_slice(&payload[offset..offset + WIDTH]);
        values.push(u64::from_le_bytes(word));
    }
    values
}

/// Decodes a host-encoded list of `f32` values.
///
/// Layout: one count byte, then `count` little-endian 4-byte values. Bytes after the last
/// value are ignored.
///
/// # Panics
/// Panics if `data` is empty or shorter than the count byte announces.
fn decode_f32(data: Vec<u8>) -> Vec<f32> {
    let (&count, payload) = data.split_first().unwrap();

    (0..usize::from(count))
        .map(|index| {
            let start = index * 4;
            let mut word = [0u8; 4];
            word.copy_from_slice(&payload[start..start + 4]);
            f32::from_le_bytes(word)
        })
        .collect()
}