#![allow(missing_docs)]
/// Versioned FFI bindings gathered into one local namespace.
///
/// Each `ml_vN` module is re-exported under a short alias (`v1`..`v5`) for explicitly
/// versioned calls such as `ffi::v5::start_training`, and is also glob-imported so that
/// unversioned names used in this file (e.g. `ffi::drop_future`) resolve here.
mod ffi {
    pub use crate::ffi::{ml_v1 as v1, ml_v1::*};
    pub use crate::ffi::{ml_v2 as v2, ml_v2::*};
    pub use crate::ffi::{ml_v3 as v3, ml_v3::*};
    pub use crate::ffi::{ml_v4 as v4, ml_v4::*};
    pub use crate::ffi::{ml_v5 as v5, ml_v5::*};
}
use crate::ffi::{ErrorCode, FFIResult};
use crate::Error;
use bytemuck::Zeroable;
pub use ffi::{
EpisodeState, ExperimentStatus, FixedBufferString, FutureHandle, HardwareType, InferenceHandle,
SnapshotFormat, TrainingHandle,
};
use std::time::Duration;
use std::{future::Future, marker::PhantomData, task::Poll};
#[doc(hidden)]
pub use ffi::v4::API as FFI_API;
/// Per-observation result returned by training and inference submissions.
#[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
#[derive(Clone)]
pub enum Response {
    /// Used in this module for observations whose episode state is neither
    /// `Initial` nor `Running`.
    EndOfEpisode,
    /// Decoded actions for an `Initial` or `Running` observation.
    Actions(Actions),
}
/// Actions decoded from an FFI response buffer.
#[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
#[derive(Clone)]
pub struct Actions {
    /// One `f32` per action, in the order they appear in the response buffer.
    pub actions: Vec<f32>,
    /// Every constructor in this module currently sets this to `0.0`.
    pub value: f32,
}
/// A named scalar attached to an observation as metadata.
#[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
#[derive(Clone, Debug)]
pub struct Metric {
    /// Must fit in the FFI's fixed-size name buffer when converted with
    /// `Metric::convert_to_ffi_safe`.
    pub name: String,
    pub value: f32,
}
/// Summary of one experiment as reported by `list_experiments`.
#[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
#[derive(Clone, Debug)]
pub struct ExperimentInfo {
    /// Decoded from the FFI fixed-size string buffer.
    pub id: String,
    /// Decoded from the FFI fixed-size string buffer.
    pub name: String,
    /// Built from the FFI's nanoseconds-since-epoch timestamp.
    #[cfg(feature = "time")]
    pub started_at: crate::time::Instant,
    /// `None` when the FFI reports an end timestamp of `0`.
    #[cfg(feature = "time")]
    pub ended_at: Option<crate::time::Instant>,
    /// Copied verbatim from the FFI record.
    /// NOTE(review): unit is not visible here (seconds?) — confirm against the FFI.
    pub max_duration: u64,
    pub experiment_status: ExperimentStatus,
    pub worker_count: u32,
}
/// Cloud worker configuration used by `Experiment::new_from_settings`.
#[derive(Clone, Debug)]
pub struct CloudWorkerSettings {
    /// Sent to the FFI layer as the worker module id.
    pub module_id: String,
    /// Widened to `u32` before being sent to the FFI layer.
    pub worker_count: u8,
}
/// Parameters for starting a run with `Experiment::new_from_settings`.
#[derive(Clone, Debug)]
pub struct StartExperimentSettings {
    pub trial_name: String,
    pub experiment_name: String,
    pub configuration: String,
    /// Sent to the FFI layer as a whole number of seconds.
    pub duration: Duration,
    pub hardware: HardwareType,
    /// `None` sends an empty module id and a worker count of zero.
    pub cloud_worker: Option<CloudWorkerSettings>,
}
/// Training-protocol parameters forwarded to the FFI `start_training` call.
#[derive(Clone, PartialEq, Debug)]
pub struct ProtocolConfig {
    pub feature_count: u32,
    pub action_count: u32,
    pub hidden: u32,
    pub alpha: f32,
    pub use_terminal_masking: bool,
    pub learning_rate_init: f32,
    pub learning_rate_end: f32,
    pub learning_rate_steps: u32,
    pub batch_size: u32,
    pub memory_min_size: u32,
    pub memory_max_size: u32,
    pub gamma: f32,
    pub rollout_length: u32,
}
impl Default for ProtocolConfig {
    /// Built-in defaults.
    ///
    /// `feature_count` and `action_count` default to `0`; presumably callers must set
    /// them for their environment (TODO confirm).
    fn default() -> Self {
        let learning_rate_steps: u32 = 300_000;
        let memory_max_size: u32 = 1_000_000;
        Self {
            feature_count: 0,
            action_count: 0,
            hidden: 1_024,
            alpha: 0.01,
            use_terminal_masking: true,
            learning_rate_init: 0.01,
            learning_rate_end: 0.000_1,
            learning_rate_steps,
            batch_size: 4_000,
            memory_min_size: 5_000,
            memory_max_size,
            gamma: 0.97,
            rollout_length: 5,
        }
    }
}
impl ProtocolConfig {
    /// Builds the FFI representation of this configuration.
    ///
    /// Every field is copied unchanged except `use_terminal_masking`, which the
    /// FFI struct stores as a `u32` (`1` for `true`, `0` for `false`).
    fn to_ffi(&self) -> ffi::ProtocolConfig {
        let &Self {
            feature_count,
            action_count,
            hidden,
            alpha,
            use_terminal_masking,
            learning_rate_init,
            learning_rate_end,
            learning_rate_steps,
            batch_size,
            memory_min_size,
            memory_max_size,
            gamma,
            rollout_length,
        } = self;
        ffi::ProtocolConfig {
            use_terminal_masking: u32::from(use_terminal_masking),
            feature_count,
            action_count,
            hidden,
            alpha,
            learning_rate_init,
            learning_rate_end,
            learning_rate_steps,
            batch_size,
            memory_min_size,
            memory_max_size,
            gamma,
            rollout_length,
        }
    }
}
impl Metric {
    /// Converts this metric into the FFI (v3) metric representation.
    ///
    /// # Errors
    /// Returns [`FixedBufferFromStringError::LengthTooLarge`] when `name` is longer,
    /// in bytes, than the FFI name buffer.
    pub fn convert_to_ffi_safe(&self) -> Result<ffi::v3::Metric, FixedBufferFromStringError> {
        let name = convert_string_to_fixed_buffer(&self.name)?;
        let value = self.value;
        Ok(ffi::v3::Metric { name, value })
    }
}
/// Error returned when a Rust string cannot be packed into a fixed-size FFI buffer.
#[derive(Debug, Copy, Clone)]
pub enum FixedBufferFromStringError {
    /// The string's byte length exceeds the buffer capacity.
    LengthTooLarge { found: usize, expected_max: usize },
}

impl std::error::Error for FixedBufferFromStringError {}

impl std::fmt::Display for FixedBufferFromStringError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::LengthTooLarge {
                found,
                expected_max,
            } => write!(
                f,
                "The string has a length which is larger than what will fit in the buffer. Found string of length {found} but can only fit strings of length at most {expected_max}"
            ),
        }
    }
}
/// Copies `string`'s UTF-8 bytes into a zero-padded `N`-byte FFI buffer.
///
/// # Errors
/// Returns [`FixedBufferFromStringError::LengthTooLarge`] when the string is longer
/// than `N` bytes.
fn convert_string_to_fixed_buffer<const N: usize>(
    string: &str,
) -> Result<ffi::FixedBufferString<N>, FixedBufferFromStringError> {
    let source = string.as_bytes();
    let length = source.len();
    if length > N {
        return Err(FixedBufferFromStringError::LengthTooLarge {
            found: length,
            expected_max: N,
        });
    }
    let mut bytes = [0u8; N];
    // `length <= N`, so the zip copies every source byte; the tail stays zeroed.
    for (dst, &src) in bytes.iter_mut().zip(source) {
        *dst = src;
    }
    Ok(ffi::FixedBufferString {
        length: length as u32,
        bytes,
    })
}
#[derive(Clone, Debug)]
struct InferenceResponse {
_id: u64,
actions: Option<Vec<f32>>,
}
fn decode_inference_results(data: Vec<u8>) -> Result<Vec<InferenceResponse>, Error> {
let mut decoded = vec![];
let mut offset = 0;
let version = data[0]; if version != 0 {
return Err(Error::ApiNotAvailable); }
let response_count = data[1]; let action_count = u16::from_ne_bytes(data[2..4].try_into().map_err(|_err| Error::Internal)?);
offset += 4;
for _ in 0..response_count {
let id = u64::from_ne_bytes(
data[offset..offset + 8]
.try_into()
.map_err(|_err| Error::Internal)?,
);
offset += 8;
let has_data = data[offset];
offset += 1;
let actions = if has_data != 0 {
let action_bytes = (action_count * 4) as usize;
let action_u8s = &data[offset..offset + action_bytes];
let actions = action_u8s
.chunks_exact(4)
.map(|bytes| {
Ok(f32::from_ne_bytes(
bytes.try_into().map_err(|_err| Error::Internal)?,
))
})
.collect::<Result<Vec<f32>, Error>>()?;
offset += action_bytes;
Some(actions)
} else {
None
};
decoded.push(InferenceResponse { _id: id, actions });
}
Ok(decoded)
}
/// Error returned when a fixed-size FFI string buffer cannot become a `String`.
#[derive(Debug, Clone)]
pub enum FixedBufferToStringError {
    /// The buffer's stored length exceeds its capacity.
    LengthTooLarge { found: u32, expected_max: usize },
    /// The bytes within the stored length are not valid UTF-8.
    InvalidUtf8(std::string::FromUtf8Error),
}

impl std::error::Error for FixedBufferToStringError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::LengthTooLarge { .. } => None,
            Self::InvalidUtf8(inner) => Some(inner),
        }
    }
}

impl std::fmt::Display for FixedBufferToStringError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidUtf8(e) => write!(f, "The string did not contain valid utf8 data: {e}"),
            Self::LengthTooLarge {
                found,
                expected_max,
            } => write!(
                f,
                "The buffer has a length which is larger than the actual buffer. Found length {found} but expected a length of at most {expected_max}"
            ),
        }
    }
}
/// Decodes the first `buffer.length` bytes of a fixed-size FFI buffer as UTF-8.
///
/// # Errors
/// - [`FixedBufferToStringError::LengthTooLarge`] if the stored length exceeds `N`.
/// - [`FixedBufferToStringError::InvalidUtf8`] if those bytes are not valid UTF-8.
fn convert_fixed_buffer_to_string<const N: usize>(
    buffer: &FixedBufferString<N>,
) -> Result<String, FixedBufferToStringError> {
    let declared = buffer.length;
    // `get` yields `None` exactly when `declared > N`.
    let contents = buffer
        .bytes
        .get(..declared as usize)
        .ok_or(FixedBufferToStringError::LengthTooLarge {
            found: declared,
            expected_max: N,
        })?;
    String::from_utf8(contents.to_vec()).map_err(FixedBufferToStringError::InvalidUtf8)
}
/// Handle to a running training experiment.
///
/// Dropping it calls `ffi::stop_training` on the handle (see `impl Drop for Experiment`).
pub struct Experiment {
    pub(crate) ctx: TrainingHandle,
}
/// Environment dimensions reported in an experiment's configuration.
pub struct ArkExperimentConfig {
    pub num_features: u32,
    pub num_actions: u32,
}
/// Experiment configuration fetched through the FFI layer.
pub struct ExperimentConfig {
    pub ark_config: ArkExperimentConfig,
    /// Module configuration bytes decoded as UTF-8.
    pub module_config: String,
}
impl ExperimentConfig {
pub(crate) async fn from_ffi_async(
future: Result<ffi::ExperimentConfigFuture, ErrorCode>,
) -> Result<Self, Error> {
let config = future.map_err(|_err| Error::Internal)?;
let ark_config = MLFuture::<ffi::ArkExperimentConfig>::new(Ok(config.ark_config))
.await
.map_err(|_err| Error::Internal)?;
let module_config_bytes = MLFuture::<Vec<u8>>::new(Ok(config.module_config))
.await
.map_err(|_err| Error::Internal)?;
let module_config =
String::from_utf8(module_config_bytes).map_err(|_err| Error::Internal)?;
Ok(ExperimentConfig {
ark_config: ArkExperimentConfig {
num_features: ark_config.num_features,
num_actions: ark_config.num_actions,
},
module_config,
})
}
}
/// Forwards `buffer` to the FFI `onnx_to_cervo` conversion and returns the converted bytes.
#[inline]
pub fn onnx_to_cervo(buffer: &[u8]) -> Result<Vec<u8>, Error> {
    let converted = ffi::onnx_to_cervo(buffer)?;
    Ok(converted)
}
/// Returns a future resolving to whether the FFI layer can reach the hive at `host`.
///
/// The FFI request is issued immediately; the future only waits for its answer.
#[inline]
pub fn can_connect_to_hive(host: &HiveHost) -> impl Future<Output = Result<bool, Error>> {
    let (url, port) = host.to_url_port();
    let probe = ffi::can_connect_to_hive(url, port);
    MLFuture::<bool>::new(probe)
}
/// Forwards `cid` to the FFI `set_worker_module_link` call.
#[inline]
pub fn set_worker_module_link(cid: &str) -> Result<(), Error> {
    ffi::set_worker_module_link(cid)?;
    Ok(())
}
/// Fetches the configuration of run `run_id` from the registry on `host`.
///
/// Failures are reported as described on `ExperimentConfig::from_ffi_async`.
pub fn experiment_config_from_registry(
    host: &HiveHost,
    run_id: &str,
) -> impl Future<Output = Result<ExperimentConfig, Error>> {
    let (url, port) = host.to_url_port();
    let pending = ffi::experiment_config_from_registry(url, port, run_id);
    ExperimentConfig::from_ffi_async(pending)
}
pub fn config_from_registry(
host: &HiveHost,
run_id: &str,
) -> impl Future<Output = Result<String, Error>> {
let (hive_url, hive_port) = host.to_url_port();
let fut = ffi::v4::raw_experiment_config_from_registry(hive_url, hive_port, run_id);
async move {
MLFuture::<String>::new(fut)
.await
.map_err(|_err| Error::Internal)
}
}
/// Downloads the snapshot bytes of run `run_id` from the registry on `host`.
///
/// The FFI request is issued immediately; the returned future waits for the bytes.
pub fn snapshot_from_registry(
    host: &HiveHost,
    run_id: &str,
) -> impl Future<Output = Result<Vec<u8>, Error>> {
    let (url, port) = host.to_url_port();
    let download = ffi::download_snapshot_from_registry(url, port, run_id);
    MLFuture::<Vec<u8>>::new(download)
}
pub fn list_experiments(
host: &HiveHost,
) -> impl Future<Output = Result<Vec<ExperimentInfo>, Error>> {
let (hive_url, hive_port) = host.to_url_port();
let future =
MLFuture::<Vec<ffi::ExperimentInfo>>::new(ffi::list_experiments(hive_url, hive_port));
async move {
let experiments = future.await?;
Ok(experiments
.iter()
.map(|experiment| ExperimentInfo {
id: convert_fixed_buffer_to_string(&experiment.id).unwrap(),
name: convert_fixed_buffer_to_string(&experiment.name).unwrap(),
#[cfg(feature = "time")]
started_at: crate::api::time::Instant::from_nanos_since_epoch(
experiment.started_at,
),
#[cfg(feature = "time")]
ended_at: if experiment.ended_at != 0 {
Some(crate::api::time::Instant::from_nanos_since_epoch(
experiment.ended_at,
))
} else {
None
},
max_duration: experiment.max_duration,
experiment_status: experiment.experiment_status,
worker_count: experiment.worker_count,
})
.collect())
}
}
#[allow(clippy::upper_case_acronyms)]
struct MLFuture<T> {
handle: Result<FutureHandle, Error>,
_phantom: PhantomData<T>,
}
impl<T> MLFuture<T> {
fn new(handle: Result<FutureHandle, ErrorCode>) -> Self {
Self {
handle: handle.map_err(Error::from),
_phantom: Default::default(),
}
}
}
impl Experiment {
    /// Starts a new training run on `host` through the v4 `start_training` FFI call.
    ///
    /// The FFI call is issued immediately; the returned future only waits for the
    /// resulting training handle. A `checkpoint` of `None` is sent as an empty string.
    // The argument list mirrors the flat FFI signature.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        host: &HiveHost,
        game_name: &str,
        experiment_name: &str,
        num_remote_workers: u32,
        config: &str,
        checkpoint: Option<&str>,
        training_duration_in_seconds: u64,
        protocol: &ProtocolConfig,
    ) -> impl Future<Output = Result<Self, Error>> {
        let (hive_url, hive_port) = host.to_url_port();
        let handle = ffi::v4::start_training(
            hive_url,
            hive_port,
            game_name,
            experiment_name,
            num_remote_workers,
            config,
            checkpoint.unwrap_or(""),
            training_duration_in_seconds,
            &protocol.to_ffi(),
        );
        let future = MLFuture::<TrainingHandle>::new(handle);
        async move { Ok(Experiment { ctx: future.await? }) }
    }

    /// Starts a new training run described by `settings` (v5 `start_training`).
    ///
    /// Without `settings.cloud_worker`, an empty module id and zero workers are sent.
    /// The duration is sent in whole seconds. Durations too long for a `u32` are
    /// clamped to `u32::MAX` rather than wrapped.
    pub async fn new_from_settings(
        host: HiveHost,
        settings: StartExperimentSettings,
    ) -> Result<Self, Error> {
        let (hive_url, hive_port) = host.to_url_port();
        let (module_id, worker_count) = settings
            .cloud_worker
            .as_ref()
            .map_or(("", 0), |s| (s.module_id.as_str(), s.worker_count));
        // `as u32` would silently wrap very long durations; saturate instead.
        let duration_secs = u32::try_from(settings.duration.as_secs()).unwrap_or(u32::MAX);
        let handle = ffi::v5::start_training(
            hive_url,
            hive_port,
            &settings.trial_name,
            &settings.configuration,
            &settings.experiment_name,
            module_id,
            duration_secs,
            u32::from(worker_count),
            settings.hardware,
        );
        let ctx = MLFuture::<TrainingHandle>::new(handle).await?;
        Ok(Experiment { ctx })
    }

    /// Attaches to the existing run `run_id` on `host`.
    ///
    /// The FFI request is issued immediately; the future waits for the handle.
    pub fn connect_to_experiment(
        host: &HiveHost,
        run_id: &str,
    ) -> impl Future<Output = Result<Self, Error>> {
        let (hive_url, hive_port) = host.to_url_port();
        let handle = ffi::connect_to_experiment(hive_url, hive_port, run_id);
        let future = MLFuture::<TrainingHandle>::new(handle);
        async move {
            let ctx = future.await?;
            Ok(Experiment { ctx })
        }
    }

    /// Fetches this experiment's configuration.
    pub fn config(&self) -> impl Future<Output = Result<ExperimentConfig, Error>> {
        ExperimentConfig::from_ffi_async(ffi::experiment_config(self.ctx))
    }

    /// Fetches this experiment's raw configuration string (v5 FFI).
    ///
    /// Every failure is reported as [`Error::Internal`].
    pub fn raw_config(&self) -> impl Future<Output = Result<String, Error>> {
        // Wrap the handle right away so `MLFuture`'s `Drop` releases it even if
        // the returned future is dropped before it is first polled.
        let future = MLFuture::<String>::new(ffi::v5::raw_experiment_config(self.ctx));
        async move { future.await.map_err(|_err| Error::Internal) }
    }

    /// Downloads the experiment's metrics as `(name, value)` pairs.
    ///
    /// # Errors
    /// - Download errors are propagated.
    /// - A metric whose name cannot be decoded yields [`Error::InvalidArguments`].
    pub fn metrics(&self) -> impl Future<Output = Result<Vec<(String, f32)>, Error>> {
        let future = MLFuture::<Vec<ffi::v3::Metric>>::new(ffi::v3::download_metrics(self.ctx));
        async move {
            let metrics = future.await?;
            metrics
                .into_iter()
                .map(|metric| {
                    let name = convert_fixed_buffer_to_string(&metric.name)?;
                    Ok((name, metric.value))
                })
                .collect::<Result<Vec<_>, FixedBufferToStringError>>()
                .map_err(|_e| Error::InvalidArguments)
        }
    }

    /// Downloads the current model snapshot bytes.
    pub fn snapshot(&self) -> impl Future<Output = Result<Vec<u8>, Error>> {
        MLFuture::<Vec<u8>>::new(ffi::download_snapshot(self.ctx))
    }

    /// Pushes training observations, submits them, and returns one [`Response`]
    /// per observation.
    ///
    /// Responses are matched to observations by position. `Initial`/`Running`
    /// observations receive [`Response::Actions`] (with `value` set to `0.0`); all
    /// others receive [`Response::EndOfEpisode`].
    ///
    /// # Errors
    /// - [`Error::InvalidArguments`] if a metadata metric name does not fit the FFI buffer.
    /// - FFI push/submit errors are propagated.
    /// - Response-decoding errors are propagated.
    /// - [`Error::Internal`] if an `Initial`/`Running` observation's response has no
    ///   action data.
    pub fn push_training_experiences(
        &self,
        observations: &[Observation<'_>],
    ) -> Result<Vec<Response>, Error> {
        for observation in observations {
            ffi::v3::push_training_observation(
                self.ctx,
                observation.id,
                observation.episode_state as u32,
                observation.reward,
                observation.features,
                &[],
                &metadata_to_ffi(observation.metadata)?,
            )
            .map_err(Error::from)?;
        }
        let response_bytes =
            ffi::v5::submit_training_observations(self.ctx).map_err(Error::from)?;
        let decoded = decode_inference_results(response_bytes)?;
        let mut responses = vec![Response::EndOfEpisode; observations.len()];
        // NOTE(review): `zip` stops at the shorter side. If fewer responses are decoded
        // than observations were pushed, the rest keep `EndOfEpisode`. Confirm the FFI
        // always returns one response per observation.
        for ((observation, response), inference_result) in
            observations.iter().zip(&mut responses).zip(decoded)
        {
            if matches!(
                observation.episode_state,
                EpisodeState::Initial | EpisodeState::Running
            ) {
                *response = Response::Actions(Actions {
                    actions: inference_result.actions.ok_or(Error::Internal)?,
                    value: 0.0,
                });
            }
        }
        Ok(responses)
    }

    /// Pushes demonstration observations (with their actions) and submits them.
    ///
    /// `Demonstration::timestep` is not forwarded to the FFI layer.
    pub fn push_demonstration_experiences(
        &self,
        demonstrations: &[Demonstration<'_>],
    ) -> Result<(), Error> {
        for demonstration in demonstrations {
            ffi::v3::push_training_observation(
                self.ctx,
                demonstration.observation.id,
                demonstration.observation.episode_state as u32,
                demonstration.observation.reward,
                demonstration.observation.features,
                demonstration.actions,
                &metadata_to_ffi(demonstration.observation.metadata)?,
            )
            .map_err(Error::from)?;
        }
        ffi::submit_training_demonstration_observations(self.ctx).map_err(Error::from)
    }

    /// Pushes augmented observations (with their actions) and submits them.
    pub fn push_augmented_experiences(
        &self,
        observations: &[AugmentedObservation<'_>],
    ) -> Result<(), Error> {
        for observation in observations {
            ffi::v3::push_training_observation(
                self.ctx,
                observation.observation.id,
                observation.observation.episode_state as u32,
                observation.observation.reward,
                observation.observation.features,
                observation.actions,
                &metadata_to_ffi(observation.observation.metadata)?,
            )
            .map_err(Error::from)?;
        }
        ffi::submit_training_augmented_observations(self.ctx).map_err(Error::from)
    }

    /// Stops the experiment through the FFI layer, consuming the handle.
    // NOTE(review): `self` is dropped when this returns, so `Drop` also calls
    // `ffi::stop_training` on the same handle. Confirm the FFI tolerates that.
    pub fn stop_experiment(self) -> Result<(), Error> {
        ffi::stop_experiment(self.ctx).map_err(Error::from)
    }
}
impl Drop for Experiment {
    /// Best-effort stop of the training context; any FFI error is discarded.
    // NOTE(review): after `stop_experiment(self)` this still runs on the same handle.
    // Confirm `ffi::stop_training` is safe to call after `ffi::stop_experiment`.
    fn drop(&mut self) {
        let _ = ffi::stop_training(self.ctx);
    }
}
/// Handle to a local inference context created from a model snapshot.
///
/// Dropping it calls `ffi::stop_inference`, logging any failure.
pub struct Inference {
    pub(crate) ctx: InferenceHandle,
}
fn metadata_to_ffi(metadata: &[Metric]) -> Result<Vec<ffi::v3::Metric>, Error> {
let res = metadata
.iter()
.map(Metric::convert_to_ffi_safe)
.collect::<Result<Vec<_>, _>>()
.map_err(|_e| Error::InvalidArguments)?;
Ok(res)
}
impl Inference {
pub fn new(snapshot_data: &[u8], snapshot_format: SnapshotFormat) -> Result<Self, Error> {
ffi::v5::start_inference(snapshot_data, snapshot_format as u32)
.map(|ctx| Self { ctx })
.map_err(Error::from)
}
pub fn evaluate(&self, observations: &[Observation<'_>]) -> Result<Vec<Response>, Error> {
for observation in observations {
ffi::v3::push_inference_observation(
self.ctx,
observation.id,
observation.episode_state as u32,
observation.reward,
observation.features,
&metadata_to_ffi(observation.metadata)?,
)
.map_err(Error::from)?;
}
let response_bytes =
ffi::v5::submit_inference_observations(self.ctx).map_err(Error::from)?;
let actions = decode_inference_results(response_bytes)?;
let mut responses = vec![Response::EndOfEpisode; observations.len()];
for ((observation, response), actions) in
observations.iter().zip(&mut responses).zip(actions)
{
if matches!(
observation.episode_state,
EpisodeState::Initial | EpisodeState::Running
) {
*response = Response::Actions(Actions {
actions: actions.actions.unwrap(),
value: 0.0,
});
}
}
Ok(responses)
}
}
impl Drop for Inference {
    /// Releases the inference context. Failures are logged, never propagated.
    fn drop(&mut self) {
        let outcome = ffi::stop_inference(self.ctx);
        if let Err(err) = outcome {
            log::error!("{:?}", err);
        }
    }
}
/// One observation borrowed from the caller and pushed to the FFI layer.
#[derive(Clone)]
pub struct Observation<'a> {
    /// Identifier forwarded to the FFI layer with the observation.
    pub id: u64,
    /// Sent to the FFI layer as its `u32` discriminant.
    pub episode_state: EpisodeState,
    pub reward: f32,
    pub features: &'a [f32],
    /// Converted with `metadata_to_ffi`; names must fit the FFI name buffer.
    pub metadata: &'a [Metric],
}
/// An observation paired with the actions that were taken.
pub struct Demonstration<'a> {
    /// NOTE(review): not forwarded to the FFI layer by
    /// `Experiment::push_demonstration_experiences`. Confirm whether it is needed.
    pub timestep: u64,
    pub actions: &'a [f32],
    pub observation: Observation<'a>,
}
/// An observation plus the actions that accompany it in
/// `Experiment::push_augmented_experiences`.
pub struct AugmentedObservation<'a> {
    pub observation: Observation<'a>,
    pub actions: &'a [f32],
}
/// Which hive endpoint to talk to.
#[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
#[derive(Clone, Debug, PartialEq)]
pub enum HiveHost {
    Local,
    Cloud,
    Custom { host: String, port: u32 },
}
impl HiveHost {
    /// Returns the `(host, port)` pair passed to the FFI layer.
    ///
    /// - `Local` maps to `("localhost", 12356)`.
    /// - `Cloud` maps to an empty host and port `0`.
    /// - `Custom` returns its stored values.
    pub(crate) fn to_url_port(&self) -> (&str, u32) {
        match *self {
            Self::Custom { ref host, port } => (host.as_str(), port),
            Self::Cloud => ("", 0),
            Self::Local => ("localhost", 12356),
        }
    }
}
fn poll_simple<T: Zeroable>(
handle: &mut Result<FutureHandle, Error>,
poll: fn(FutureHandle) -> FFIResult<ffi::PollSimple>,
take: fn(FutureHandle, &mut T) -> FFIResult<()>,
) -> Poll<Result<T, Error>> {
let mut inner = || -> Result<Poll<Result<T, Error>>, Error> {
let raw_handle = (*handle)?;
let poll = poll(raw_handle)?;
if poll.ready {
let mut data = Zeroable::zeroed();
*handle = Err(Error::NotFound);
take(raw_handle, &mut data)?;
Ok(Poll::Ready(Ok(data)))
} else {
Ok(Poll::Pending)
}
};
match inner() {
Ok(poll) => poll,
Err(err) => Poll::Ready(Err(err)),
}
}
fn poll_vec<T: Zeroable>(
handle: &mut Result<FutureHandle, Error>,
poll: fn(FutureHandle) -> FFIResult<ffi::PollVec>,
take: fn(FutureHandle, &mut [T]) -> FFIResult<()>,
) -> Poll<Result<Vec<T>, Error>> {
let mut inner = || -> Result<Poll<Result<Vec<T>, Error>>, Error> {
let raw_handle = (*handle)?;
let poll = poll(raw_handle)?;
if poll.ready {
let mut data = bytemuck::allocation::zeroed_slice_box(poll.len as usize).into_vec();
*handle = Err(Error::NotFound);
take(raw_handle, &mut data)?;
Ok(Poll::Ready(Ok(data)))
} else {
Ok(Poll::Pending)
}
};
match inner() {
Ok(poll) => poll,
Err(err) => Poll::Ready(Err(err)),
}
}
fn poll_string(handle: &mut Result<FutureHandle, Error>) -> Poll<Result<String, Error>> {
let mut inner = || -> Result<Poll<Result<String, Error>>, Error> {
let raw_handle = (*handle)?;
let poll = ffi::v4::poll_future_string(raw_handle)?;
if poll.ready {
*handle = Err(Error::NotFound);
let s = ffi::v4::take_future_string(raw_handle)?;
Ok(Poll::Ready(Ok(s)))
} else {
Ok(Poll::Pending)
}
};
match inner() {
Ok(poll) => poll,
Err(err) => Poll::Ready(Err(err)),
}
}
impl Future for MLFuture<bool> {
type Output = Result<bool, Error>;
fn poll(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
poll_simple(
&mut self.handle,
ffi::poll_future_bool,
ffi::take_future_bool,
)
}
}
impl Future for MLFuture<String> {
type Output = Result<String, Error>;
fn poll(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
poll_string(&mut self.handle)
}
}
impl Future for MLFuture<TrainingHandle> {
type Output = Result<TrainingHandle, Error>;
fn poll(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
poll_simple(
&mut self.handle,
ffi::poll_future_training_handle,
ffi::take_future_training_handle,
)
}
}
impl Future for MLFuture<Vec<u8>> {
type Output = Result<Vec<u8>, Error>;
fn poll(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
poll_vec(
&mut self.handle,
ffi::poll_future_vec_u8,
ffi::take_future_vec_u8,
)
}
}
impl Future for MLFuture<Vec<ffi::v3::Metric>> {
type Output = Result<Vec<ffi::v3::Metric>, Error>;
fn poll(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
poll_vec(
&mut self.handle,
ffi::v3::poll_future_vec_metric,
ffi::v3::take_future_vec_metric,
)
}
}
impl Future for MLFuture<Vec<ffi::ExperimentInfo>> {
type Output = Result<Vec<ffi::ExperimentInfo>, Error>;
fn poll(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
poll_vec(
&mut self.handle,
ffi::poll_future_vec_experiment_info,
ffi::take_future_vec_experiment_info,
)
}
}
impl Future for MLFuture<ffi::ArkExperimentConfig> {
type Output = Result<ffi::ArkExperimentConfig, Error>;
fn poll(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Self::Output> {
poll_simple(
&mut self.handle,
ffi::poll_future_ark_experiment_config,
ffi::take_future_ark_experiment_config,
)
}
}
impl<T> Drop for MLFuture<T> {
fn drop(&mut self) {
if let Ok(handle) = &self.handle {
ffi::drop_future(*handle).expect("Failed to drop future");
}
}
}