ark-api 0.17.0-pre.15

Ark API
Documentation
//! # 👾 Machine Learning API
//!
//! API for machine learning - more specifically, Hive-style reinforcement learning.
//!
//! Supports training/inferring through Hive and inferring through Tract, a Rust inference framework.
//!
//! The inference support in this module is being deprecated for gameplay use. Using
//! [`super::ml-inference`] instead is better for performance and results in simpler code.

#![allow(missing_docs)]

// Local view of the engine's ML FFI. Each API version is reachable through its short
// module name (`v1`..`v5`) and, via the glob re-exports, unqualified. Where a specific
// version is required, this file names it explicitly (e.g. `ffi::v3::Metric`,
// `ffi::v5::start_training`).
mod ffi {
    pub use crate::ffi::{ml_v1 as v1, ml_v1::*};
    pub use crate::ffi::{ml_v2 as v2, ml_v2::*};
    pub use crate::ffi::{ml_v3 as v3, ml_v3::*};
    pub use crate::ffi::{ml_v4 as v4, ml_v4::*};
    pub use crate::ffi::{ml_v5 as v5, ml_v5::*};
}

use crate::ffi::{ErrorCode, FFIResult};
use crate::Error;
use bytemuck::Zeroable;
pub use ffi::{
    EpisodeState, ExperimentStatus, FixedBufferString, FutureHandle, HardwareType, InferenceHandle,
    SnapshotFormat, TrainingHandle,
};
use std::time::Duration;
use std::{future::Future, marker::PhantomData, task::Poll};

#[doc(hidden)]
pub use ffi::v4::API as FFI_API;

/// Per-observation result returned by [`Experiment::push_training_experiences`] and
/// [`Inference::evaluate`], in the same order as the submitted observations.
#[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
#[derive(Clone)]
pub enum Response {
    /// The observation's `episode_state` was neither `Initial` nor `Running`, so no actions
    /// are produced for it.
    EndOfEpisode,
    /// Actions for an observation whose `episode_state` is `Initial` or `Running`.
    Actions(Actions),
}

/// Actions produced for a single observation.
#[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
#[derive(Clone)]
pub struct Actions {
    /// Action values decoded from the engine's inference-result buffer.
    pub actions: Vec<f32>,
    /// NOTE(review): presumably a value estimate; every constructor in this module sets it
    /// to `0.0`.
    pub value: f32,
}

/// A named scalar sent to Hive as observation metadata (see [`Observation::metadata`]).
///
/// `name` must fit in the fixed-size FFI name buffer; see [`Metric::convert_to_ffi_safe`].
#[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
#[derive(Clone, Debug)]
pub struct Metric {
    /// Metric name.
    pub name: String,
    /// Metric value.
    pub value: f32,
}

/// Summary of an experiment as reported by [`list_experiments`].
#[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
#[derive(Clone, Debug)]
pub struct ExperimentInfo {
    /// Experiment identifier reported by the hive server. NOTE(review): presumably the
    /// `run_id` accepted by `Experiment::connect_to_experiment` — TODO confirm.
    pub id: String,
    /// Experiment name reported by the hive server.
    pub name: String,
    #[cfg(feature = "time")]
    /// Time when experiment was started.
    ///
    /// This field is only available when the `time` api is enabled.
    pub started_at: crate::time::Instant,
    #[cfg(feature = "time")]
    /// Time when experiment was ended.
    ///
    /// `None` when the engine reports an end time of `0` (presumably: not ended yet).
    /// This field is only available when the `time` api is enabled.
    pub ended_at: Option<crate::time::Instant>,
    /// The experiment will be automatically stopped after this number of seconds
    pub max_duration: u64,
    /// Status of the experiment on the hive server.
    pub experiment_status: ExperimentStatus,
    /// Number of game workers that are started on the server. May be 0 for in-client training.
    pub worker_count: u32,
}

#[derive(Clone, Debug)]
/// Settings used to construct the cloud worker when training in the cloud.
pub struct CloudWorkerSettings {
    /// The module id to run when using cloud workers.
    pub module_id: String,

    /// The number of workers to use in the cloud (widened to `u32` for the engine call).
    pub worker_count: u8,
}

/// Everything [`Experiment::new_from_settings`] needs to start a training run.
#[derive(Clone, Debug)]
pub struct StartExperimentSettings {
    /// The name of the trial
    pub trial_name: String,

    /// The experiment the trial belongs to
    pub experiment_name: String,

    /// The configuration to train with
    pub configuration: String,

    /// The duration to train for (at most). Sent to the engine in whole seconds; any
    /// sub-second part is discarded.
    pub duration: Duration,

    /// The hardware class to use (in the cloud).
    pub hardware: HardwareType,

    /// Settings for the cloud worker, if used. `None` sends an empty module id and zero
    /// workers to the engine.
    pub cloud_worker: Option<CloudWorkerSettings>,
}

/// Hyper-parameters for the Hive training protocol, passed to the engine when starting an
/// experiment with [`Experiment::new`].
///
/// Each field is forwarded unchanged to the FFI struct, except `use_terminal_masking`,
/// which is sent as `0`/`1`. The meaning of each parameter is defined by the Hive server.
/// NOTE(review): per-field documentation is still missing.
#[derive(Clone, PartialEq, Debug)]
pub struct ProtocolConfig {
    pub feature_count: u32,
    pub action_count: u32,
    pub hidden: u32,
    pub alpha: f32,
    pub use_terminal_masking: bool,
    pub learning_rate_init: f32,
    pub learning_rate_end: f32,
    pub learning_rate_steps: u32,
    pub batch_size: u32,
    pub memory_min_size: u32,
    pub memory_max_size: u32,
    pub gamma: f32,
    pub rollout_length: u32,
}

impl Default for ProtocolConfig {
    /// Baseline Hive protocol hyper-parameters.
    ///
    /// `feature_count` and `action_count` default to `0`. NOTE(review): presumably callers
    /// must set them for their game — confirm.
    fn default() -> Self {
        Self {
            feature_count: 0,
            action_count: 0,
            hidden: 1024,
            learning_rate_init: 1e-2,
            learning_rate_end: 1e-4,
            learning_rate_steps: 300_000,
            batch_size: 4000,
            memory_min_size: 5000,
            memory_max_size: 1_000_000,
            alpha: 0.01,
            gamma: 0.97,
            rollout_length: 5,
            use_terminal_masking: true,
        }
    }
}

impl ProtocolConfig {
    /// Converts this configuration into its FFI representation.
    ///
    /// Every field is copied unchanged except `use_terminal_masking`, which crosses the FFI
    /// boundary as a `u32` (`1` for `true`, `0` for `false`).
    fn to_ffi(&self) -> ffi::ProtocolConfig {
        // Exhaustive destructuring: a field added to `ProtocolConfig` but not forwarded
        // here becomes a compile error instead of a silently dropped setting.
        let Self {
            feature_count,
            action_count,
            hidden,
            alpha,
            use_terminal_masking,
            learning_rate_init,
            learning_rate_end,
            learning_rate_steps,
            batch_size,
            memory_min_size,
            memory_max_size,
            gamma,
            rollout_length,
        } = *self;

        ffi::ProtocolConfig {
            feature_count,
            action_count,
            hidden,
            alpha,
            use_terminal_masking: u32::from(use_terminal_masking),
            learning_rate_init,
            learning_rate_end,
            learning_rate_steps,
            batch_size,
            memory_min_size,
            memory_max_size,
            gamma,
            rollout_length,
        }
    }
}

impl Metric {
    /// Converts this metric into the fixed-size FFI representation used by the v3 ML API.
    ///
    /// # Errors
    /// Returns [`FixedBufferFromStringError::LengthTooLarge`] when `name` does not fit in
    /// the FFI name buffer.
    pub fn convert_to_ffi_safe(&self) -> Result<ffi::v3::Metric, FixedBufferFromStringError> {
        let name = convert_string_to_fixed_buffer(&self.name)?;
        let value = self.value;
        Ok(ffi::v3::Metric { name, value })
    }
}

/// Error returned when a string cannot be packed into a fixed-size FFI string buffer.
#[derive(Debug, Copy, Clone)]
pub enum FixedBufferFromStringError {
    /// The string's byte length (`found`) exceeds the buffer capacity (`expected_max`).
    LengthTooLarge { found: usize, expected_max: usize },
}

impl std::error::Error for FixedBufferFromStringError {}

impl std::fmt::Display for FixedBufferFromStringError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum, so this pattern is irrefutable; adding a variant turns it
        // into a compile error, just like an exhaustive `match` would.
        let Self::LengthTooLarge {
            found,
            expected_max,
        } = self;
        write!(f, "The string has a length which is larger than what will fit in the buffer. Found string of length {found} but can only fit strings of length at most {expected_max}")
    }
}

/// Copies `string` into a zero-padded fixed-size buffer with a capacity of `N` bytes.
///
/// The recorded `length` is the UTF-8 byte length of `string`.
///
/// # Errors
/// Returns [`FixedBufferFromStringError::LengthTooLarge`] if that byte length exceeds `N`.
fn convert_string_to_fixed_buffer<const N: usize>(
    string: &str,
) -> Result<ffi::FixedBufferString<N>, FixedBufferFromStringError> {
    let source = string.as_bytes();
    let length = source.len();
    if length > N {
        return Err(FixedBufferFromStringError::LengthTooLarge {
            found: length,
            expected_max: N,
        });
    }

    // `length <= N`, so zipping copies every source byte; the tail stays zeroed.
    let mut bytes = [0_u8; N];
    for (dst, src) in bytes.iter_mut().zip(source) {
        *dst = *src;
    }

    Ok(ffi::FixedBufferString {
        length: length as u32,
        bytes,
    })
}

/// One decoded entry from the engine's inference-result buffer.
#[derive(Clone, Debug)]
struct InferenceResponse {
    /// Id stored in the buffer. Currently unused: callers match responses to observations
    /// by position.
    _id: u64,
    /// Decoded actions, or `None` when the entry's has-data flag was zero.
    actions: Option<Vec<f32>>,
}

fn decode_inference_results(data: Vec<u8>) -> Result<Vec<InferenceResponse>, Error> {
    let mut decoded = vec![];
    let mut offset = 0;

    let version = data[0]; // from_ne
    if version != 0 {
        return Err(Error::ApiNotAvailable); // this code doesn't know how to handle the version in the data.
    }

    let response_count = data[1]; // from_ne
    let action_count = u16::from_ne_bytes(data[2..4].try_into().map_err(|_err| Error::Internal)?);
    offset += 4;

    for _ in 0..response_count {
        let id = u64::from_ne_bytes(
            data[offset..offset + 8]
                .try_into()
                .map_err(|_err| Error::Internal)?,
        );
        offset += 8;

        let has_data = data[offset];
        offset += 1;

        let actions = if has_data != 0 {
            let action_bytes = (action_count * 4) as usize;
            let action_u8s = &data[offset..offset + action_bytes];
            let actions = action_u8s
                .chunks_exact(4)
                .map(|bytes| {
                    Ok(f32::from_ne_bytes(
                        bytes.try_into().map_err(|_err| Error::Internal)?,
                    ))
                })
                .collect::<Result<Vec<f32>, Error>>()?;

            offset += action_bytes;

            Some(actions)
        } else {
            None
        };
        decoded.push(InferenceResponse { _id: id, actions });
    }

    Ok(decoded)
}

/// Error returned when a fixed-size FFI string buffer cannot be turned back into a
/// [`String`].
#[derive(Debug, Clone)]
pub enum FixedBufferToStringError {
    /// The buffer's recorded length (`found`) exceeds its capacity (`expected_max`).
    LengthTooLarge { found: u32, expected_max: usize },
    /// The used part of the buffer is not valid UTF-8.
    InvalidUtf8(std::string::FromUtf8Error),
}

impl std::error::Error for FixedBufferToStringError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::LengthTooLarge { .. } => None,
            Self::InvalidUtf8(inner) => Some(inner),
        }
    }
}

impl std::fmt::Display for FixedBufferToStringError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidUtf8(e) => write!(f, "The string did not contain valid utf8 data: {e}"),
            Self::LengthTooLarge {
                found,
                expected_max,
            } => write!(
                f,
                "The buffer has a length which is larger than the actual buffer. Found length {found} but expected a length of at most {expected_max}"
            ),
        }
    }
}

/// Decodes the first `buffer.length` bytes of a fixed-size FFI buffer as a UTF-8 [`String`].
///
/// # Errors
/// - [`FixedBufferToStringError::LengthTooLarge`] if `buffer.length` exceeds the capacity `N`.
/// - [`FixedBufferToStringError::InvalidUtf8`] if those bytes are not valid UTF-8.
fn convert_fixed_buffer_to_string<const N: usize>(
    buffer: &FixedBufferString<N>,
) -> Result<String, FixedBufferToStringError> {
    // `get` yields `None` exactly when the recorded length runs past the N-byte array.
    let used = buffer
        .bytes
        .get(..buffer.length as usize)
        .ok_or(FixedBufferToStringError::LengthTooLarge {
            found: buffer.length,
            expected_max: N,
        })?;

    String::from_utf8(used.to_vec()).map_err(FixedBufferToStringError::InvalidUtf8)
}

/// A machine learning context. Takes batches of observations, and produces
/// corresponding batches of actions.
pub struct Experiment {
    /// Engine handle for this training session; `Drop` passes it to `ffi::stop_training`.
    pub(crate) ctx: TrainingHandle,
}

/// Ark-specific part of an [`ExperimentConfig`], copied from the engine's
/// `ffi::ArkExperimentConfig`.
pub struct ArkExperimentConfig {
    /// Number of features. NOTE(review): semantics inferred from the name — confirm.
    pub num_features: u32,
    /// Number of actions. NOTE(review): semantics inferred from the name — confirm.
    pub num_actions: u32,
}

/// Configuration of an experiment, as returned by [`Experiment::config`] and
/// [`experiment_config_from_registry`].
pub struct ExperimentConfig {
    /// Additional configuration options passed to the `start_training` call.
    pub ark_config: ArkExperimentConfig,
    /// The custom configuration passed to the `start_training` call (decoded as UTF-8).
    pub module_config: String,
}

impl ExperimentConfig {
    pub(crate) async fn from_ffi_async(
        future: Result<ffi::ExperimentConfigFuture, ErrorCode>,
    ) -> Result<Self, Error> {
        let config = future.map_err(|_err| Error::Internal)?;
        let ark_config = MLFuture::<ffi::ArkExperimentConfig>::new(Ok(config.ark_config))
            .await
            .map_err(|_err| Error::Internal)?;
        let module_config_bytes = MLFuture::<Vec<u8>>::new(Ok(config.module_config))
            .await
            .map_err(|_err| Error::Internal)?;
        let module_config =
            String::from_utf8(module_config_bytes).map_err(|_err| Error::Internal)?;

        Ok(ExperimentConfig {
            ark_config: ArkExperimentConfig {
                num_features: ark_config.num_features,
                num_actions: ark_config.num_actions,
            },
            module_config,
        })
    }
}

/// Asks the engine to convert `buffer` into Cervo format and returns the converted bytes.
///
/// NOTE(review): `buffer` is presumably a serialized ONNX model (per the function name).
///
/// # Errors
/// Propagates any engine error, converted into [`Error`].
#[inline]
pub fn onnx_to_cervo(buffer: &[u8]) -> Result<Vec<u8>, Error> {
    let converted = ffi::onnx_to_cervo(buffer)?;
    Ok(converted)
}

/// Checks whether the hive server at `host` can be reached.
///
/// The returned future resolves to `Ok(true)` if we can connect to hive.
#[inline]
pub fn can_connect_to_hive(host: &HiveHost) -> impl Future<Output = Result<bool, Error>> {
    let (url, port) = host.to_url_port();
    let request = ffi::can_connect_to_hive(url, port);
    MLFuture::<bool>::new(request)
}

/// Passes `cid` to the engine as the worker module link.
///
/// # Errors
/// Propagates any engine error, converted into [`Error`].
#[inline]
pub fn set_worker_module_link(cid: &str) -> Result<(), Error> {
    ffi::set_worker_module_link(cid)?;
    Ok(())
}

/// Fetches the [`ExperimentConfig`] of run `run_id` from the hive registry at `host`.
pub fn experiment_config_from_registry(
    host: &HiveHost,
    run_id: &str,
) -> impl Future<Output = Result<ExperimentConfig, Error>> {
    let (url, port) = host.to_url_port();
    let request = ffi::experiment_config_from_registry(url, port, run_id);
    ExperimentConfig::from_ffi_async(request)
}

pub fn config_from_registry(
    host: &HiveHost,
    run_id: &str,
) -> impl Future<Output = Result<String, Error>> {
    let (hive_url, hive_port) = host.to_url_port();
    let fut = ffi::v4::raw_experiment_config_from_registry(hive_url, hive_port, run_id);
    async move {
        MLFuture::<String>::new(fut)
            .await
            .map_err(|_err| Error::Internal)
    }
}

/// Downloads the snapshot of run `run_id` from the hive registry at `host`.
pub fn snapshot_from_registry(
    host: &HiveHost,
    run_id: &str,
) -> impl Future<Output = Result<Vec<u8>, Error>> {
    let (url, port) = host.to_url_port();
    let download = ffi::download_snapshot_from_registry(url, port, run_id);
    MLFuture::<Vec<u8>>::new(download)
}

/// Lists all running experiments on the hive server.
///
/// You can connect to a running experiment using `Experiment::connect_to_experiment`.
pub fn list_experiments(
    host: &HiveHost,
) -> impl Future<Output = Result<Vec<ExperimentInfo>, Error>> {
    let (hive_url, hive_port) = host.to_url_port();
    let future =
        MLFuture::<Vec<ffi::ExperimentInfo>>::new(ffi::list_experiments(hive_url, hive_port));
    async move {
        let experiments = future.await?;

        Ok(experiments
            .iter()
            .map(|experiment| ExperimentInfo {
                // The engine should be giving us valid strings, so we unwrap here.
                // If the engine doesn't give us valid strings then something is very wrong.
                id: convert_fixed_buffer_to_string(&experiment.id).unwrap(),
                name: convert_fixed_buffer_to_string(&experiment.name).unwrap(),
                #[cfg(feature = "time")]
                started_at: crate::api::time::Instant::from_nanos_since_epoch(
                    experiment.started_at,
                ),
                #[cfg(feature = "time")]
                ended_at: if experiment.ended_at != 0 {
                    Some(crate::api::time::Instant::from_nanos_since_epoch(
                        experiment.ended_at,
                    ))
                } else {
                    None
                },
                max_duration: experiment.max_duration,
                experiment_status: experiment.experiment_status,
                worker_count: experiment.worker_count,
            })
            .collect())
    }
}

/// Future representing an asynchronous task on the engine side.
///
/// The future outputs `Result<T, Error>`.
///
/// TODO: This future, like most other futures that ark exposes to modules, does not adhere to correct future behavior.
/// In particular if you poll it and it returns Pending, it will *not* wake itself up when it is ready to do work again.
/// Instead it assumes it will be polled regularly until it is done.
// Keeps the "ML" acronym upper-case in the type name.
#[allow(clippy::upper_case_acronyms)]
struct MLFuture<T> {
    /// Engine handle of the pending task. `Err` when creating the task failed or once the
    /// result has been taken by a poll helper; `Drop` only releases handles that are `Ok`.
    handle: Result<FutureHandle, Error>,
    /// Ties the future to its output type `T` without storing a `T`.
    _phantom: PhantomData<T>,
}

impl<T> MLFuture<T> {
    /// Create a new future, wrapping a handle received from the engine.
    ///
    /// The call to generate the handle may have failed, in which case you can pass a handle with an `Err` value.
    fn new(handle: Result<FutureHandle, ErrorCode>) -> Self {
        Self {
            handle: handle.map_err(Error::from),
            _phantom: Default::default(),
        }
    }
}

impl Experiment {
    /// Start training with a new experiment.
    ///
    /// - `host`: where the hive server exists.
    /// - `game_name`: identifier for a type of game, e.g. `pong` or `moreau-arena`.
    /// - `experiment_name`: identifier for this experiment. This can be any arbitrary data.
    ///   If the server already has an experiment with that name then a number will be
    ///   appended to make it unique.
    /// - `num_remote_workers`: number of workers that should be started in the cloud. Using
    ///   a non-zero number is only supported when using a cloud `host`.
    /// - `config`: a piece of data that contains all the necessary information to re-create
    ///   the training run. This is not used anywhere internally, but when connecting to an
    ///   existing experiment you can query it using [`Experiment::config`]. The only
    ///   constraint on this data is that it must be valid json. This is very useful for
    ///   remote workers as they can download this configuration.
    /// - `checkpoint`: optional reference to checkpoint data to load from; `None` is sent to
    ///   the engine as an empty string. TODO: Api for listing checkpoints.
    /// - `training_duration_in_seconds`: number of seconds that the training will run for.
    ///   The hive server will stop the experiment after this amount of time. This is
    ///   primarily useful when training in the cloud with remote workers, since otherwise
    ///   it's easier to just let the experiment continue until the `Experiment` struct is
    ///   dropped (which will also stop the hive server when no remote workers are used).
    /// - `protocol`: Hive protocol hyper-parameters, sent via `ProtocolConfig::to_ffi`.
    ///   NOTE(review): the network's input/output sizes appear to be configured through
    ///   `protocol.feature_count` / `protocol.action_count` — confirm.
    // The argument list mirrors the `ffi::v4::start_training` call.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        host: &HiveHost,
        game_name: &str,
        experiment_name: &str,
        num_remote_workers: u32,
        config: &str,
        checkpoint: Option<&str>,
        training_duration_in_seconds: u64,
        protocol: &ProtocolConfig,
    ) -> impl Future<Output = Result<Self, Error>> {
        let (hive_url, hive_port) = host.to_url_port();
        let handle = ffi::v4::start_training(
            hive_url,
            hive_port,
            game_name,
            experiment_name,
            num_remote_workers,
            config,
            checkpoint.unwrap_or(""),
            training_duration_in_seconds,
            &protocol.to_ffi(),
        );

        let future = MLFuture::<TrainingHandle>::new(handle);
        async move { Ok(Experiment { ctx: future.await? }) }
    }

    /// Start training with a new experiment.
    ///
    /// - `host`: where the hive server exists.
    /// - `settings`: all data required to spawn the experiment. When
    ///   `settings.cloud_worker` is `None`, an empty module id and zero workers are sent.
    ///
    /// `settings.duration` is sent in whole seconds. Durations beyond `u32::MAX` seconds
    /// saturate at that value instead of silently wrapping.
    pub async fn new_from_settings(
        host: HiveHost,
        settings: StartExperimentSettings,
    ) -> Result<Self, Error> {
        let (hive_url, hive_port) = host.to_url_port();

        let (module_id, worker_count) = settings
            .cloud_worker
            .as_ref()
            .map_or(("", 0), |s| (s.module_id.as_str(), s.worker_count));

        // The FFI takes `u32` seconds. Saturate rather than truncate: `as u32` would turn a
        // huge duration into an arbitrary, possibly tiny, one.
        let duration_secs = u32::try_from(settings.duration.as_secs()).unwrap_or(u32::MAX);

        let handle = ffi::v5::start_training(
            hive_url,
            hive_port,
            &settings.trial_name,
            &settings.configuration,
            &settings.experiment_name,
            module_id,
            duration_secs,
            u32::from(worker_count),
            settings.hardware,
        );

        let ctx = MLFuture::<TrainingHandle>::new(handle).await?;
        Ok(Experiment { ctx })
    }

    /// Connects to an already running experiment identified by `run_id`.
    pub fn connect_to_experiment(
        host: &HiveHost,
        run_id: &str,
    ) -> impl Future<Output = Result<Self, Error>> {
        let (hive_url, hive_port) = host.to_url_port();
        let handle = ffi::connect_to_experiment(hive_url, hive_port, run_id);
        let future = MLFuture::<TrainingHandle>::new(handle);

        async move {
            let ctx = future.await?;

            Ok(Experiment { ctx })
        }
    }

    /// Retrieves the configuration for the experiment.
    ///
    /// This may require a download from the hive server.
    pub fn config(&self) -> impl Future<Output = Result<ExperimentConfig, Error>> {
        ExperimentConfig::from_ffi_async(ffi::experiment_config(self.ctx))
    }

    /// Retrieves the configuration for the experiment as an unparsed string.
    ///
    /// This may require a download from the hive server. The engine request is wrapped in
    /// an [`MLFuture`] immediately, so its handle is released even if the returned future is
    /// dropped before being polled.
    ///
    /// # Errors
    /// Any failure is reported as [`Error::Internal`].
    pub fn raw_config(&self) -> impl Future<Output = Result<String, Error>> {
        let future = MLFuture::<String>::new(ffi::v5::raw_experiment_config(self.ctx));
        async move { future.await.map_err(|_err| Error::Internal) }
    }

    /// Retrieves current training metrics for Hive as `(name, value)` pairs.
    ///
    /// # Errors
    /// Propagates download errors. A metric name that cannot be decoded is reported as
    /// [`Error::InvalidArguments`].
    pub fn metrics(&self) -> impl Future<Output = Result<Vec<(String, f32)>, Error>> {
        let result = MLFuture::<Vec<ffi::v3::Metric>>::new(ffi::v3::download_metrics(self.ctx));
        async move {
            let metrics = result.await?;
            metrics
                .into_iter()
                .map(|metric| {
                    let name = convert_fixed_buffer_to_string(&metric.name)?;
                    Ok((name, metric.value))
                })
                .collect::<Result<Vec<_>, FixedBufferToStringError>>()
                .map_err(|_e| Error::InvalidArguments)
        }
    }

    /// Retrieves a trained snapshot.
    pub fn snapshot(&self) -> impl Future<Output = Result<Vec<u8>, Error>> {
        MLFuture::<Vec<u8>>::new(ffi::download_snapshot(self.ctx))
    }

    /// Trains using the current batch of observations.
    ///
    /// Returns one [`Response`] per observation, in order. Observations whose
    /// `episode_state` is `Initial` or `Running` get [`Response::Actions`]; all others get
    /// [`Response::EndOfEpisode`].
    ///
    /// # Errors
    /// - [`Error::InvalidArguments`] if a metadata metric name is too long.
    /// - [`Error::Internal`] if the engine returns no actions for a live observation, or
    ///   returns a malformed result buffer.
    /// - Any engine error from pushing or submitting the observations.
    pub fn push_training_experiences(
        &self,
        observations: &[Observation<'_>],
    ) -> Result<Vec<Response>, Error> {
        for observation in observations {
            ffi::v3::push_training_observation(
                self.ctx,
                observation.id,
                observation.episode_state as u32,
                observation.reward,
                observation.features,
                // Plain training experiences carry no recorded actions.
                &[],
                &metadata_to_ffi(observation.metadata)?,
            )
            .map_err(Error::from)?;
        }

        let response_bytes =
            ffi::v5::submit_training_observations(self.ctx).map_err(Error::from)?;
        let actions = decode_inference_results(response_bytes)?;

        let mut responses = vec![Response::EndOfEpisode; observations.len()];

        // Results are matched to observations by position.
        for ((observation, response), inference_result) in
            observations.iter().zip(&mut responses).zip(actions)
        {
            if matches!(
                observation.episode_state,
                EpisodeState::Initial | EpisodeState::Running
            ) {
                *response = Response::Actions(Actions {
                    actions: inference_result.actions.ok_or(Error::Internal)?,
                    value: 0.0,
                });
            }
        }

        Ok(responses)
    }

    /// Send demonstration experiences recorded in ark to Hive.
    /// Will return an error if the Hive protocol doesn't support imitation learning.
    ///
    /// `Demonstration::timestep` is not forwarded to the engine.
    pub fn push_demonstration_experiences(
        &self,
        demonstrations: &[Demonstration<'_>],
    ) -> Result<(), Error> {
        for demonstration in demonstrations {
            ffi::v3::push_training_observation(
                self.ctx,
                demonstration.observation.id,
                demonstration.observation.episode_state as u32,
                demonstration.observation.reward,
                demonstration.observation.features,
                demonstration.actions,
                &metadata_to_ffi(demonstration.observation.metadata)?,
            )
            .map_err(Error::from)?;
        }

        ffi::submit_training_demonstration_observations(self.ctx).map_err(Error::from)
    }

    /// Send augmented experiences augmented in ark to Hive.
    ///
    /// The list of observations represents one timestep for a set of agents.
    /// You should not include multiple observations for a single agent, call this method multiple times instead.
    pub fn push_augmented_experiences(
        &self,
        observations: &[AugmentedObservation<'_>],
    ) -> Result<(), Error> {
        for observation in observations {
            ffi::v3::push_training_observation(
                self.ctx,
                observation.observation.id,
                observation.observation.episode_state as u32,
                observation.observation.reward,
                observation.observation.features,
                observation.actions,
                &metadata_to_ffi(observation.observation.metadata)?,
            )
            .map_err(Error::from)?;
        }
        ffi::submit_training_augmented_observations(self.ctx).map_err(Error::from)
    }

    /// Stops an experiment and all associated workers.
    ///
    /// If you just want to disconnect this client (without stopping potential remote workers) then simply `drop` the experiment.
    /// `Drop` still runs after this call and invokes `ffi::stop_training` on the same
    /// handle; its result is ignored there.
    pub fn stop_experiment(self) -> Result<(), Error> {
        ffi::stop_experiment(self.ctx).map_err(Error::from)
    }
}

impl Drop for Experiment {
    /// Calls `ffi::stop_training` on this experiment's handle.
    ///
    /// Per `Experiment::stop_experiment`'s docs, dropping only disconnects this client and
    /// does not stop remote workers. Best-effort: any error is discarded because `drop`
    /// cannot report it.
    fn drop(&mut self) {
        let _ = ffi::stop_training(self.ctx);
    }
}

/// An engine inference context created from a snapshot by [`Inference::new`].
///
/// The engine context is stopped when this value is dropped.
pub struct Inference {
    /// Engine handle for the inference context.
    pub(crate) ctx: InferenceHandle,
}

fn metadata_to_ffi(metadata: &[Metric]) -> Result<Vec<ffi::v3::Metric>, Error> {
    let res = metadata
        .iter()
        .map(Metric::convert_to_ffi_safe)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|_e| Error::InvalidArguments)?;
    Ok(res)
}

impl Inference {
    pub fn new(snapshot_data: &[u8], snapshot_format: SnapshotFormat) -> Result<Self, Error> {
        ffi::v5::start_inference(snapshot_data, snapshot_format as u32)
            .map(|ctx| Self { ctx })
            .map_err(Error::from)
    }

    /// Evaluates actions from the current batch of observations. Ignores `episode_state`, `reward`.
    /// Does not contribute to training.
    pub fn evaluate(&self, observations: &[Observation<'_>]) -> Result<Vec<Response>, Error> {
        for observation in observations {
            ffi::v3::push_inference_observation(
                self.ctx,
                observation.id,
                observation.episode_state as u32,
                observation.reward,
                observation.features,
                &metadata_to_ffi(observation.metadata)?,
            )
            .map_err(Error::from)?;
        }

        let response_bytes =
            ffi::v5::submit_inference_observations(self.ctx).map_err(Error::from)?;
        let actions = decode_inference_results(response_bytes)?;

        let mut responses = vec![Response::EndOfEpisode; observations.len()];

        for ((observation, response), actions) in
            observations.iter().zip(&mut responses).zip(actions)
        {
            if matches!(
                observation.episode_state,
                EpisodeState::Initial | EpisodeState::Running
            ) {
                *response = Response::Actions(Actions {
                    actions: actions.actions.unwrap(),
                    value: 0.0,
                });
            }
        }

        Ok(responses)
    }
}

impl Drop for Inference {
    /// Stops the engine-side inference context owned by this value.
    ///
    /// A failure cannot be returned from `drop`, so it is logged at error level instead.
    fn drop(&mut self) {
        let stopped = ffi::stop_inference(self.ctx);
        if let Err(err) = &stopped {
            log::error!("{:?}", err);
        }
    }
}

/// A single observation submitted to [`Experiment`] or [`Inference`].
#[derive(Clone)]
pub struct Observation<'a> {
    /// Identifier forwarded to the engine with the observation. NOTE(review): exact
    /// semantics (agent id? episode id?) are defined by the engine — confirm.
    pub id: u64,
    /// Episode state; `Initial`/`Running` observations receive [`Response::Actions`].
    pub episode_state: EpisodeState,
    /// Reward forwarded to the engine with this observation.
    pub reward: f32,
    /// Feature values forwarded to the engine.
    pub features: &'a [f32],
    /// Named metrics sent alongside the observation.
    pub metadata: &'a [Metric],
}

/// A recorded demonstration step for [`Experiment::push_demonstration_experiences`].
pub struct Demonstration<'a> {
    /// Timestep of the demonstration. Currently not forwarded to the engine.
    pub timestep: u64,
    /// Actions taken in this step, forwarded to the engine.
    pub actions: &'a [f32],
    /// The observation the actions were taken for.
    pub observation: Observation<'a>,
}

/// An observation paired with actions, for [`Experiment::push_augmented_experiences`].
pub struct AugmentedObservation<'a> {
    /// The observation being submitted.
    pub observation: Observation<'a>,
    /// Actions forwarded to the engine together with the observation.
    pub actions: &'a [f32],
}

/// Location of the hive server to talk to.
#[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
#[derive(Clone, Debug, PartialEq)]
pub enum HiveHost {
    /// A hive server on this machine (`localhost:12356`).
    Local,
    /// The cloud hive server (sent to the engine as an empty host and port `0`).
    Cloud,
    /// An explicit host and port.
    Custom { host: String, port: u32 },
}

impl HiveHost {
    /// Returns the `(host, port)` pair passed to the engine's FFI calls.
    pub(crate) fn to_url_port(&self) -> (&str, u32) {
        const LOCAL_ENDPOINT: (&str, u32) = ("localhost", 12356);
        // NOTE(review): presumably the engine interprets an empty host / port 0 as "use the
        // cloud endpoint" — confirm.
        const CLOUD_ENDPOINT: (&str, u32) = ("", 0);

        match self {
            Self::Custom { host, port } => (host.as_str(), *port),
            Self::Local => LOCAL_ENDPOINT,
            Self::Cloud => CLOUD_ENDPOINT,
        }
    }
}

fn poll_simple<T: Zeroable>(
    handle: &mut Result<FutureHandle, Error>,
    poll: fn(FutureHandle) -> FFIResult<ffi::PollSimple>,
    take: fn(FutureHandle, &mut T) -> FFIResult<()>,
) -> Poll<Result<T, Error>> {
    let mut inner = || -> Result<Poll<Result<T, Error>>, Error> {
        let raw_handle = (*handle)?;
        let poll = poll(raw_handle)?;
        if poll.ready {
            let mut data = Zeroable::zeroed();

            // Replace the handle so that we will not try to drop the future later. The `take` function consumes the handle.
            *handle = Err(Error::NotFound);

            take(raw_handle, &mut data)?;

            Ok(Poll::Ready(Ok(data)))
        } else {
            Ok(Poll::Pending)
        }
    };

    match inner() {
        Ok(poll) => poll,
        Err(err) => Poll::Ready(Err(err)),
    }
}

fn poll_vec<T: Zeroable>(
    handle: &mut Result<FutureHandle, Error>,
    poll: fn(FutureHandle) -> FFIResult<ffi::PollVec>,
    take: fn(FutureHandle, &mut [T]) -> FFIResult<()>,
) -> Poll<Result<Vec<T>, Error>> {
    let mut inner = || -> Result<Poll<Result<Vec<T>, Error>>, Error> {
        let raw_handle = (*handle)?;
        let poll = poll(raw_handle)?;
        if poll.ready {
            let mut data = bytemuck::allocation::zeroed_slice_box(poll.len as usize).into_vec();

            // Replace the handle so that we will not try to drop the future later. The `take` function consumes the handle.
            *handle = Err(Error::NotFound);

            take(raw_handle, &mut data)?;

            Ok(Poll::Ready(Ok(data)))
        } else {
            Ok(Poll::Pending)
        }
    };

    match inner() {
        Ok(poll) => poll,
        Err(err) => Poll::Ready(Err(err)),
    }
}

fn poll_string(handle: &mut Result<FutureHandle, Error>) -> Poll<Result<String, Error>> {
    let mut inner = || -> Result<Poll<Result<String, Error>>, Error> {
        let raw_handle = (*handle)?;
        let poll = ffi::v4::poll_future_string(raw_handle)?;
        if poll.ready {
            // Replace the handle so that we will not try to drop the future later. The `take` function consumes the handle.
            *handle = Err(Error::NotFound);

            let s = ffi::v4::take_future_string(raw_handle)?;

            Ok(Poll::Ready(Ok(s)))
        } else {
            Ok(Poll::Pending)
        }
    };

    match inner() {
        Ok(poll) => poll,
        Err(err) => Poll::Ready(Err(err)),
    }
}

impl Future for MLFuture<bool> {
    type Output = Result<bool, Error>;

    /// Polls the engine-side boolean future. Does not register `_cx`'s waker.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        let this = self.get_mut();
        poll_simple(&mut this.handle, ffi::poll_future_bool, ffi::take_future_bool)
    }
}

impl Future for MLFuture<String> {
    type Output = Result<String, Error>;

    /// Polls the engine-side string future. Does not register `_cx`'s waker.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        let this = self.get_mut();
        poll_string(&mut this.handle)
    }
}

impl Future for MLFuture<TrainingHandle> {
    type Output = Result<TrainingHandle, Error>;

    /// Polls the engine-side training-handle future. Does not register `_cx`'s waker.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        let this = self.get_mut();
        poll_simple(
            &mut this.handle,
            ffi::poll_future_training_handle,
            ffi::take_future_training_handle,
        )
    }
}

impl Future for MLFuture<Vec<u8>> {
    type Output = Result<Vec<u8>, Error>;

    /// Polls the engine-side byte-vector future. Does not register `_cx`'s waker.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        let this = self.get_mut();
        poll_vec(&mut this.handle, ffi::poll_future_vec_u8, ffi::take_future_vec_u8)
    }
}

impl Future for MLFuture<Vec<ffi::v3::Metric>> {
    type Output = Result<Vec<ffi::v3::Metric>, Error>;

    /// Polls the engine-side metrics future. Does not register `_cx`'s waker.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        let this = self.get_mut();
        poll_vec(
            &mut this.handle,
            ffi::v3::poll_future_vec_metric,
            ffi::v3::take_future_vec_metric,
        )
    }
}

impl Future for MLFuture<Vec<ffi::ExperimentInfo>> {
    type Output = Result<Vec<ffi::ExperimentInfo>, Error>;

    /// Polls the engine-side experiment-list future. Does not register `_cx`'s waker.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        let this = self.get_mut();
        poll_vec(
            &mut this.handle,
            ffi::poll_future_vec_experiment_info,
            ffi::take_future_vec_experiment_info,
        )
    }
}

impl Future for MLFuture<ffi::ArkExperimentConfig> {
    type Output = Result<ffi::ArkExperimentConfig, Error>;

    /// Polls the engine-side Ark experiment-config future. Does not register `_cx`'s waker.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        let this = self.get_mut();
        poll_simple(
            &mut this.handle,
            ffi::poll_future_ark_experiment_config,
            ffi::take_future_ark_experiment_config,
        )
    }
}

// Only a single drop implementation because Drop cannot be specialized (for some internal compiler reasons).
impl<T> Drop for MLFuture<T> {
    fn drop(&mut self) {
        if let Ok(handle) = &self.handle {
            ffi::drop_future(*handle).expect("Failed to drop future");
        }
    }
}