// ark-api-ffi 0.16.0
//
// Ark low-level Wasm FFI API
// Documentation
// Declares the identity of this API module: a 64-bit id plus the name "ml-v3".
// NOTE(review): `define_api_id!` is defined elsewhere in the crate; presumably the id must be
// unique per API module and must not change once published — confirm against the macro.
define_api_id!(0xaed6_0109_4f07_3c82, "ml-v3");

pub use super::ml_v1::FixedBufferString;
pub use super::ml_v1::FutureHandle;
pub use super::ml_v1::InferenceHandle;
pub use super::ml_v1::PollVec;
pub use super::ml_v1::TrainingHandle;
use crate::FFIResult;
use bytemuck::Pod;
use bytemuck::Zeroable;

/// A named `f32` value that is passed across the FFI boundary.
///
/// Slices of `Metric` are attached to observations as `metadata` (see
/// `push_training_observation` and `push_inference_observation`) and are written into the
/// caller-provided buffer by `take_future_vec_metric`.
///
/// The type is `#[repr(C)]` and derives `Pod`/`Zeroable`, so its field order is part of the
/// ABI: do not reorder or change the fields.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub struct Metric {
    /// Name of the metric, stored inline in a fixed-size string buffer.
    // NOTE(review): presumably the `64` is the buffer capacity in bytes — confirm in `ml_v1`.
    pub name: FixedBufferString<64>,
    /// Value of the metric.
    pub value: f32,
}

// `start_training` below takes ten parameters, which exceeds the default threshold of clippy's
// `too_many_arguments` lint, so the lint is silenced for the whole module.
#[allow(clippy::too_many_arguments)]
// NOTE(review): `imports = "ark-ml-v3"` is presumably the name of the Wasm import module that
// the host provides these functions under — confirm against `ark_bindgen`.
#[ark_api_macros::ark_bindgen(imports = "ark-ml-v3")]
mod ml {
    use super::*;

    extern "C" {
        /// Starts training.
        ///
        /// Returns a `FutureHandle`; the promise outputs a `TrainingHandle`.
        ///
        /// Will start from a checkpoint if the `checkpoint` parameter is non-empty.
        pub fn start_training(
            hive_url: &str,
            hive_port: u32,
            game_name: &str,
            experiment_name: &str,
            num_features: u32,
            num_actions: u32,
            num_remote_workers: u32,
            config: &str,
            checkpoint: &str,
            training_duration_in_seconds: u64,
        ) -> FFIResult<FutureHandle>;

        /// Pushes an observation with metadata into a queue to be used later.
        ///
        /// The `actions` parameter should only be set for augmented observations and
        /// demonstration observations.
        ///
        /// See `submit_training_observations`, `submit_training_augmented_observations` and
        /// `submit_training_demonstration_observations`.
        // NOTE(review): the `submit_*` functions referenced above are not declared in this
        // `extern` block; presumably they live in another version of the ML API — confirm.
        pub fn push_training_observation(
            context: TrainingHandle,
            id: u64,
            // Actually a ml_v1::EpisodeState, but due to ffi limitations we cannot pass that here
            episode_state: u32,
            reward: f32,
            features: &[f32],
            actions: &[f32],
            metadata: &[Metric],
        ) -> FFIResult<()>;

        /// Pushes an observation with metadata into a queue to be used later.
        ///
        /// Inference counterpart of `push_training_observation`: it takes an `InferenceHandle`
        /// instead of a `TrainingHandle` and has no `actions` parameter.
        pub fn push_inference_observation(
            context: InferenceHandle,
            id: u64,
            // Actually a ml_v1::EpisodeState, but due to ffi limitations we cannot pass that here
            episode_state: u32,
            reward: f32,
            features: &[f32],
            metadata: &[Metric],
        ) -> FFIResult<()>;

        /// Downloads metrics asynchronously.
        ///
        /// Returns a `FutureHandle`; the promise outputs a `Vec<Metric>`.
        pub fn download_metrics(context: TrainingHandle) -> FFIResult<FutureHandle>;

        /// Polls a future whose output is a `Vec<Metric>` and returns its `PollVec` state.
        // NOTE(review): presumably meant for the handle returned by `download_metrics`, with
        // the `PollVec` reporting readiness and the element count — confirm in `ml_v1`.
        pub fn poll_future_vec_metric(handle: FutureHandle) -> FFIResult<PollVec>;

        /// Takes the `Vec<Metric>` output of a future, writing the metrics into `data`.
        // NOTE(review): presumably `data` must be sized from a preceding successful
        // `poll_future_vec_metric` call — confirm the length contract with the host side.
        pub fn take_future_vec_metric(handle: FutureHandle, data: &mut [Metric]) -> FFIResult<()>;
    }
}

// Re-export everything from the private `ml` module so its items are reachable directly from
// this module's path.
pub use ml::*;