// ark-api-ffi 0.16.0
//
// Ark low-level Wasm FFI API
// Documentation
// Declares the identifier of this API: a 64-bit numeric id paired with the name "ml-v1".
// NOTE(review): `define_api_id!` is defined elsewhere in the crate; presumably the id must stay stable for
// as long as the "ml-v1" interface is unchanged — confirm against the macro before editing either value.
define_api_id!(0xaed6_0109_4f07_3c80, "ml-v1");

use crate::FFIResult;
use bytemuck::CheckedBitPattern;
use bytemuck::NoUninit;
use bytemuck::Pod;
use bytemuck::Zeroable;

/// Opaque handle identifying a training session.
///
/// Output by the futures returned from `start_training` and `start_training_from_checkpoint`, passed as
/// `context` to the training functions, and destroyed by `stop_training`.
pub type TrainingHandle = u64;
/// Opaque handle identifying an inference session.
///
/// Returned by `start_inference` and passed as `context` to the inference functions and to `stop_inference`.
pub type InferenceHandle = u64;

/// A handle to a future on the engine-side.
///
/// The handle is untyped because the ffi proc macros do not support generics where we need them.
/// Therefore it is the API's job to keep track of what type the future will actually return and to call
/// the corresponding `poll_future_*` and `take_future_*` functions.
/// See e.g. `ml::take_future_bool`.
///
/// A handle that is never passed to a `take_future_*` function must be released with `ml::drop_future`.
pub type FutureHandle = u64;

/// A variable length string that can be up to N bytes long.
///
/// The string is stored inline as a length prefix followed by a fixed-capacity byte buffer, so the type
/// is `Copy` and has a fixed `repr(C)` layout: `size_of::<Self>()` is `4 + N` rounded up to a multiple of 4.
/// `N` therefore has to be a multiple of 4 for the layout to be free of trailing padding.
///
/// NOTE(review): nothing in this type enforces `length <= N`, and the text encoding is not stated here
/// (presumably UTF-8) — readers should validate both; confirm against the host implementation.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct FixedBufferString<const N: usize> {
    /// Length of the string in bytes
    pub length: u32,
    /// Contents of the string.
    /// Only the first `length` bytes are part of the string.
    /// The remaining bytes carry no meaning and must be ignored.
    pub bytes: [u8; N],
}

#[allow(unsafe_code)]
// SAFETY: All fields are `Pod` and the struct is `repr(C)`. The layout is free of padding only when `N` is a
// multiple of 4 (the alignment of `length`); for any other `N` trailing padding follows `bytes`.
// The instantiations visible in this file (`N = 128` and `N = 64`) satisfy that requirement.
// NOTE(review): this impl is generic over every `N`, so nothing prevents an instantiation that does have
// trailing padding — enforce `N % 4 == 0` before introducing other sizes.
// Can't be derived due to limitations of derive macro with generics.
unsafe impl<const N: usize> Pod for FixedBufferString<N> {}

#[allow(unsafe_code)]
// SAFETY: Both fields (`u32` and `[u8; N]`) are `Zeroable` for every `N`, so the all-zero bit pattern is a
// valid value (an empty string). Can't be derived due to limitations of derive macro with generics.
unsafe impl<const N: usize> Zeroable for FixedBufferString<N> {}

/// A named scalar value.
///
/// Passed as per-observation metadata to `push_training_observation_with_metadata` and
/// `push_inference_observation_with_metadata`, and output by the future returned from `download_metrics`.
///
/// The layout is `repr(C)`: 28 name bytes followed by a 4-byte float, 32 bytes in total.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub struct Metric {
    /// Name of the metric, stored inline in a fixed 28-byte buffer.
    ///
    /// NOTE(review): the encoding and how shorter names are terminated or padded is not visible here —
    /// presumably UTF-8 padded with zero bytes; confirm against the host implementation.
    pub name: [u8; 28],
    /// Value of the metric.
    pub value: f32,
}

/// Description of a single experiment, as output by the future returned from `list_experiments`.
///
/// The struct is `repr(C)` and crosses the FFI boundary as raw bytes, so field order and field types are
/// part of the ABI. With the current fields the layout contains no padding (offsets 0, 132, 200, 208, 216,
/// 224, 228; total size 232 bytes), which the `NoUninit` derive depends on.
///
/// Note: needs to be Copy due to wasmtime constraints.
#[repr(C)]
#[derive(Clone, Copy, Debug, NoUninit, CheckedBitPattern)]
pub struct ExperimentInfo {
    /// Human readable name of the experiment.
    ///
    /// Experiment names can be long.
    pub name: FixedBufferString<128>,
    /// Internal id of the experiment
    pub id: FixedBufferString<64>,
    /// Time when experiment was started as nanoseconds since the UNIX epoch.
    pub started_at: i64,
    /// Time when experiment was ended as nanoseconds since the UNIX epoch, or 0 if it has not ended yet.
    pub ended_at: i64,
    /// The experiment will be automatically stopped after this number of seconds
    pub max_duration: u64,
    /// Status of the experiment on the hive server.
    pub experiment_status: ExperimentStatus,
    /// Number of game workers that are started on the server. May be 0 for in-client training.
    pub worker_count: u32,
}

#[allow(unsafe_code)]
// SAFETY: This is *technically* not safe yet with how bytemuck currently defines Zeroable, but since
// we're only using `Zeroable` to get `Zeroable::zeroed` (which *is* safe), we're fine. Zeroed memory is fine
// here because `ExperimentStatus` has 0 as a valid value (`ExperimentStatus::Pending`) and every other
// field is a plain integer or a `FixedBufferString`, for which all-zero bytes are valid.
unsafe impl Zeroable for ExperimentInfo {}

// Import declarations for the "ark-ml-v1" host API. The `ark_bindgen` attribute macro processes this module;
// everything declared inside is re-exported from the parent module through `pub use ml::*`.
//
// `clippy::too_many_arguments` is allowed because several of the imported functions take many parameters.
// `deprecated` is allowed so that the derived Serialize/Deserialize impls can touch the deprecated
// `Tensorflow` variant.
#[allow(clippy::too_many_arguments, deprecated)] // Deprecated to allow Serialize/Deserialize to touch Tensorflow variant
#[ark_api_macros::ark_bindgen(imports = "ark-ml-v1")]
mod ml {
    use super::*;

    /// Format of the snapshot bytes handed to `start_inference`.
    ///
    /// The discriminants are part of the FFI contract (`repr(u32)`) and must not be renumbered.
    #[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
    #[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
    #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
    #[repr(u32)]
    #[non_exhaustive]
    pub enum SnapshotFormat {
        /// Raw ONNX data.
        ONNX = 0,

        /// Tensorflow data. Kept only so that the discriminant `1` stays reserved.
        #[deprecated(note = "Tensorflow format is no longer supported")]
        Tensorflow = 1,

        /// Cervo data, e.g. as returned by `onnx_to_cervo`.
        Cervo = 2,
    }

    /// Status of an experiment, as reported in `ExperimentInfo::experiment_status`.
    ///
    /// The discriminants are part of the FFI contract (`repr(u32)`). `Pending` is `0`, which is what makes an
    /// all-zero `ExperimentInfo` a valid value.
    #[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
    #[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
    #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
    #[repr(u32)]
    #[non_exhaustive]
    pub enum ExperimentStatus {
        Pending = 0,
        Starting = 1,
        Running = 2,
        Terminating = 3,
        Stopped = 4,
        NotFound = 5,
        Aborted = 6,
        Unhealthy = 7,
    }

    #[derive(Copy, Clone, Default, Debug, NoUninit, CheckedBitPattern)]
    #[repr(C)]
    /// Indicates if a future has been completed.
    /// Should be an enum, but kept as a struct for symmetry with PollVec
    ///
    /// Returned by `poll_future_bool` and `poll_future_training_handle`.
    pub struct PollSimple {
        /// `true` once the future has completed.
        pub ready: bool,
    }

    #[derive(Copy, Clone, Default, Debug, NoUninit, CheckedBitPattern)]
    #[repr(C)]
    /// Indicates if a future that outputs a `Vec<T>` has been completed.
    /// Should be an enum, but the ffi macros don't support that
    ///
    /// Returned by the `poll_future_vec_*` functions.
    pub struct PollVec {
        /// `true` once the future has completed.
        pub ready: bool,
        /// Explicit padding that keeps `len` 4-byte aligned without implicit padding bytes
        /// (the `NoUninit` derive does not accept implicit padding).
        pub _pad: [u8; 3],
        /// The length of the output vector.
        ///
        /// NOTE(review): presumably only meaningful when `ready` is `true`, and presumably the length that the
        /// buffer passed to the matching `take_future_vec_*` function must have — confirm against the host.
        pub len: u32,
    }

    /// State of an agent's episode, attached to every observation that is pushed with one of the
    /// `push_*_observation*` functions.
    #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
    #[repr(u32)]
    pub enum EpisodeState {
        /// Episode is running.
        Running = 0,
        /// The episode has been terminated due to something the agent did.
        ///
        /// The backend may use terminal masking for this state.
        /// In that case the agent will behave as if from this state it cannot receive any more rewards.
        ///
        /// This state is useful if the agent for example has died or in some other way ended up in a non-recoverable state.
        ///
        /// The agent will receive a [`ark_api::Response::EndOfEpisode`](https://ark.embark.dev/api/ark_api/ml/enum.Response.html#variant.EndOfEpisode) when evaluating an observation with this state.
        /// This agent is considered removed from the simulation and you may not send in an observation with the same id again.
        Terminal = 1,
        /// The episode has been terminated due to an external reason that the agent cannot control.
        ///
        /// This is useful for things like timeouts that aren't necessarily due to bad behavior of the agent.
        /// No terminal masking will be used in this case.
        ///
        /// The agent will receive a [`ark_api::Response::EndOfEpisode`](https://ark.embark.dev/api/ark_api/ml/enum.Response.html#variant.EndOfEpisode) when evaluating an observation with this state.
        /// This agent is considered removed from the simulation and you may not send in an observation with the same id again.
        Interrupted = 2,
        /// First frame of the episode.
        ///
        /// The first observation that an agent sends to the engine must always be marked with this episode state.
        Initial = 3,
    }

    extern "C" {
        /// Convert a raw ONNX asset to Cervo format
        ///
        /// Returns the serialized asset.
        pub fn onnx_to_cervo(buffer: &[u8]) -> FFIResult<Vec<u8>>;

        /// Checks whether a connection to hive can be established at `hive_url`:`hive_port`.
        ///
        /// The returned future outputs a `bool` (`true` if we can connect to hive);
        /// use `poll_future_bool` and `take_future_bool` with it.
        pub fn can_connect_to_hive(hive_url: &str, hive_port: u32) -> FFIResult<FutureHandle>;

        /// Lists all experiments on the machine learning server.
        ///
        /// The returned future outputs a `Vec<ExperimentInfo>`;
        /// use `poll_future_vec_experiment_info` and `take_future_vec_experiment_info` with it.
        pub fn list_experiments(hive_url: &str, hive_port: u32) -> FFIResult<FutureHandle>;

        /// Number of actions that an experiment outputs
        pub fn experiment_action_count(context: TrainingHandle) -> FFIResult<u32>;

        /// Number of features that an experiment takes as input
        pub fn experiment_feature_count(context: TrainingHandle) -> FFIResult<u32>;

        /// Starts training.
        ///
        /// The returned future outputs a `TrainingHandle`;
        /// use `poll_future_training_handle` and `take_future_training_handle` with it.
        pub fn start_training(
            hive_url: &str,
            hive_port: u32,
            game_name: &str,
            experiment_name: &str,
            num_features: u32,
            num_actions: u32,
        ) -> FFIResult<FutureHandle>;

        /// Starts training from a checkpoint.
        ///
        /// Takes the same parameters as `start_training` plus the `checkpoint` to start from.
        ///
        /// The returned future outputs a `TrainingHandle`;
        /// use `poll_future_training_handle` and `take_future_training_handle` with it.
        pub fn start_training_from_checkpoint(
            hive_url: &str,
            hive_port: u32,
            game_name: &str,
            experiment_name: &str,
            num_features: u32,
            num_actions: u32,
            checkpoint: &str,
        ) -> FFIResult<FutureHandle>;

        /// Pushes an observation into a queue to be used later.
        ///
        /// The `actions` parameter should only be set for augmented observations and demonstration observations.
        ///
        /// See `submit_training_observations`, `submit_training_augmented_observations` and `submit_training_demonstration_observations`.
        ///
        /// NOTE(review): the meaning of the `length` parameter is not documented in this file — confirm against
        /// the host implementation before relying on it.
        pub fn push_training_observation(
            context: TrainingHandle,
            id: u64,
            episode_state: EpisodeState,
            reward: f32,
            features: &[f32],
            actions: &[f32],
            length: f32,
        ) -> FFIResult<()>;

        /// Pushes an observation with metadata into a queue to be used later.
        ///
        /// Same as `push_training_observation`, except that a slice of `Metric` values is passed as `metadata`
        /// in place of the `length` parameter.
        ///
        /// The `actions` parameter should only be set for augmented observations and demonstration observations.
        ///
        /// See `submit_training_observations`, `submit_training_augmented_observations` and `submit_training_demonstration_observations`.
        pub fn push_training_observation_with_metadata(
            context: TrainingHandle,
            id: u64,
            episode_state: EpisodeState,
            reward: f32,
            features: &[f32],
            actions: &[f32],
            metadata: &[Metric],
        ) -> FFIResult<()>;

        /// Submits observations previously pushed using `push_training_observation`.
        /// `out_actions` should be a slice with length `num_actions * observation count`.
        /// `out_values` should be a slice with length `observation count`.
        ///
        /// Each consecutive sequence of `num_actions` values in `out_actions` corresponds to the actions for an observation.
        /// If an observation was terminal or interrupted then all actions and the value for that observation will be zeroed out.
        pub fn submit_training_observations(
            context: TrainingHandle,
            out_actions: &mut [f32],
            out_values: &mut [f32],
        ) -> FFIResult<()>;

        /// Submits observations previously pushed using `push_training_observation` as augmented observations.
        pub fn submit_training_augmented_observations(context: TrainingHandle) -> FFIResult<()>;

        /// Submits observations previously pushed using `push_training_observation` as demonstration observations.
        pub fn submit_training_demonstration_observations(context: TrainingHandle)
            -> FFIResult<()>;

        /// Stops training with a given handle and destroys it.
        ///
        /// This will only disconnect the current client.
        /// If the experiment has remote workers they will continue to run.
        pub fn stop_training(context: TrainingHandle) -> FFIResult<()>;

        /// Retrieves a trained snapshot.
        ///
        /// The returned future outputs a `Vec<u8>`;
        /// use `poll_future_vec_u8` and `take_future_vec_u8` with it.
        pub fn download_snapshot(context: TrainingHandle) -> FFIResult<FutureHandle>;

        /// Downloads metrics asynchronously
        ///
        /// The returned future outputs a `Vec<Metric>`;
        /// use `poll_future_vec_metric` and `take_future_vec_metric` with it.
        pub fn download_metrics(context: TrainingHandle) -> FFIResult<FutureHandle>;

        /// Starts inference from snapshot bytes.
        ///
        /// `snapshot_data` holds the snapshot and `snapshot_format` states which format those bytes are in.
        /// Unlike `start_training`, the handle is returned directly rather than through a future.
        ///
        /// NOTE(review): presumably `num_features` and `num_actions` must match the model in the snapshot —
        /// confirm against the host implementation.
        pub fn start_inference(
            num_features: u32,
            num_actions: u32,
            snapshot_data: &[u8],
            snapshot_format: SnapshotFormat,
        ) -> FFIResult<InferenceHandle>;

        /// Submits observations previously pushed using `push_inference_observation`.
        /// `out_actions` should be a slice with length `num_actions * observation count`.
        /// `out_values` should be a slice with length `observation count`.
        ///
        /// Each consecutive sequence of `num_actions` values in `out_actions` corresponds to the actions for an observation.
        /// If an observation was terminal or interrupted then all actions and the value for that observation will be zeroed out.
        pub fn submit_inference_observations(
            context: InferenceHandle,
            out_actions: &mut [f32],
            out_values: &mut [f32],
        ) -> FFIResult<()>;

        /// Pushes an observation into a queue to be used later.
        ///
        /// See `submit_inference_observations`.
        ///
        /// NOTE(review): the meaning of the `length` parameter is not documented in this file — confirm against
        /// the host implementation before relying on it.
        pub fn push_inference_observation(
            context: InferenceHandle,
            id: u64,
            episode_state: EpisodeState,
            reward: f32,
            features: &[f32],
            length: f32,
        ) -> FFIResult<()>;

        /// Pushes an observation with metadata into a queue to be used later.
        ///
        /// Same as `push_inference_observation`, except that a slice of `Metric` values is passed as `metadata`
        /// in place of the `length` parameter.
        pub fn push_inference_observation_with_metadata(
            context: InferenceHandle,
            id: u64,
            episode_state: EpisodeState,
            reward: f32,
            features: &[f32],
            metadata: &[Metric],
        ) -> FFIResult<()>;

        /// Stops inference for the given handle.
        ///
        /// NOTE(review): presumably this also destroys the handle, like `stop_training` does — confirm.
        pub fn stop_inference(context: InferenceHandle) -> FFIResult<()>;

        /// Polls a future that outputs a boolean
        pub fn poll_future_bool(handle: FutureHandle) -> FFIResult<PollSimple>;

        /// Transfers the data of a completed future to the module side.
        ///
        /// If this method returns an error the `data` may not have been assigned.
        /// Calling this method consumes the handle and calling it again will return an error.
        pub fn take_future_bool(handle: FutureHandle, data: &mut bool) -> FFIResult<()>;

        /// Polls a future that outputs a `Vec<u8>`.
        pub fn poll_future_vec_u8(handle: FutureHandle) -> FFIResult<PollVec>;
        /// Transfers the `Vec<u8>` output of a completed future into `data`.
        ///
        /// NOTE(review): presumably `data` must have the length reported by `poll_future_vec_u8`, and the
        /// handle is consumed as documented on `take_future_bool` — confirm against the host implementation.
        pub fn take_future_vec_u8(handle: FutureHandle, data: &mut [u8]) -> FFIResult<()>;

        /// Polls a future that outputs a `Vec<Metric>`.
        pub fn poll_future_vec_metric(handle: FutureHandle) -> FFIResult<PollVec>;
        /// Transfers the `Vec<Metric>` output of a completed future into `data`.
        ///
        /// NOTE(review): presumably `data` must have the length reported by `poll_future_vec_metric`, and the
        /// handle is consumed as documented on `take_future_bool` — confirm against the host implementation.
        pub fn take_future_vec_metric(handle: FutureHandle, data: &mut [Metric]) -> FFIResult<()>;

        /// Polls a future that outputs a `TrainingHandle`.
        pub fn poll_future_training_handle(handle: FutureHandle) -> FFIResult<PollSimple>;
        /// Transfers the `TrainingHandle` output of a completed future into `data`.
        ///
        /// NOTE(review): presumably the future handle is consumed as documented on `take_future_bool` — confirm.
        pub fn take_future_training_handle(
            handle: FutureHandle,
            data: &mut TrainingHandle,
        ) -> FFIResult<()>;

        /// Polls a future that outputs a `Vec<ExperimentInfo>`.
        pub fn poll_future_vec_experiment_info(handle: FutureHandle) -> FFIResult<PollVec>;
        /// Transfers the `Vec<ExperimentInfo>` output of a completed future into `data`.
        ///
        /// NOTE(review): presumably `data` must have the length reported by `poll_future_vec_experiment_info`,
        /// and the handle is consumed as documented on `take_future_bool` — confirm against the host implementation.
        pub fn take_future_vec_experiment_info(
            handle: FutureHandle,
            data: &mut [ExperimentInfo],
        ) -> FFIResult<()>;

        /// Drops a future handle.
        ///
        /// This must be called if you do not call `take_*` to avoid a memory leak.
        pub fn drop_future(handle: FutureHandle) -> FFIResult<()>;
    }
}

pub use ml::*;