// Raw FFI declarations for the `ml-v1` machine learning API.
define_api_id!(0xaed6_0109_4f07_3c80, "ml-v1");
use crate::FFIResult;
use bytemuck::CheckedBitPattern;
use bytemuck::NoUninit;
use bytemuck::Pod;
use bytemuck::Zeroable;
/// Handle to a training session on the engine-side.
///
/// Output by the futures returned from `start_training` and `start_training_from_checkpoint`,
/// and passed as the `context` argument of the training functions.
pub type TrainingHandle = u64;
/// Handle to an inference session on the engine-side.
///
/// Returned by `start_inference` and passed as the `context` argument of the inference functions.
pub type InferenceHandle = u64;
/// A handle to a future on the engine-side.
///
/// The handle is untyped because the ffi proc macros do not support generics where we need them.
/// Therefore it is the API's job to keep track of what type the future will actually return and
/// to call the corresponding `poll_future_*` and `take_future_*` methods.
/// See e.g. `ml::take_future_bool`.
///
/// A handle that is never passed to a `take_future_*` method must be released with `ml::drop_future`.
pub type FutureHandle = u64;
/// A variable length string that can be up to N bytes long.
///
/// Stored inline as a byte count followed by a fixed-size buffer, using `#[repr(C)]` layout.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct FixedBufferString<const N: usize> {
/// Length of the string in bytes
pub length: u32,
/// Contents of the string.
/// Only the first `length` bytes are part of the string.
/// The remaining bytes are undefined.
///
/// NOTE(review): the text encoding is not stated here (presumably UTF-8) — confirm.
pub bytes: [u8; N],
}
// SAFETY: Safe because all fields are Pod, there's no padding, and the generic does not change that. Can't
// be derived due to limitations of derive macro with generics.
//
// NOTE(review): the "no padding" claim only holds when `N` is a multiple of 4 (the alignment of the
// `u32` length field); for any other `N`, `#[repr(C)]` adds trailing padding bytes. The instantiations
// visible in this file (`128` and `64`) satisfy this, but the impl itself does not enforce it.
unsafe impl<const N: usize> Pod for FixedBufferString<N> {}
// SAFETY: Safe because all fields are Zeroable and the generic does not change that. Can't
// be derived due to limitations of derive macro with generics.
unsafe impl<const N: usize> Zeroable for FixedBufferString<N> {}
/// A named `f32` value.
///
/// Used as per-observation metadata in the `push_*_observation_with_metadata` functions and as the
/// element type output by the future returned from `download_metrics`.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub struct Metric {
/// Name of the metric, stored in a fixed 28 byte buffer.
///
/// NOTE(review): encoding and handling of names shorter than 28 bytes (presumably zero-padded
/// UTF-8) are not stated here — confirm.
pub name: [u8; 28],
/// Value of the metric.
pub value: f32,
}
/// Description of a single experiment; element type output by the future returned from
/// `list_experiments`.
///
/// Note: needs to be Copy due to wasmtime constraints.
#[repr(C)]
#[derive(Clone, Copy, Debug, NoUninit, CheckedBitPattern)]
pub struct ExperimentInfo {
/// Human readable name of the experiment.
///
/// Experiment names can be long.
pub name: FixedBufferString<128>,
/// Internal id of the experiment
pub id: FixedBufferString<64>,
/// Time when experiment was started as nanoseconds since the UNIX epoch.
pub started_at: i64,
/// Time when experiment was ended as nanoseconds since the UNIX epoch, or 0 if it has not ended yet.
pub ended_at: i64,
/// The experiment will be automatically stopped after this number of seconds
///
/// Note the unit: seconds, whereas `started_at` and `ended_at` are in nanoseconds.
pub max_duration: u64,
/// Status of the experiment on the hive server.
pub experiment_status: ExperimentStatus,
/// Number of game workers that are started on the server. May be 0 for in-client training.
pub worker_count: u32,
}
// SAFETY: This is *technically* not safe yet with how bytemuck currently defines Zeroable, but since
// we're only using `Zeroable` to get `Zeroable::zeroed` (which *is* safe), we're fine. Zeroed memory is
// fine here because `ExperimentStatus` has 0 as a valid value (`ExperimentStatus::Pending`).
unsafe impl Zeroable for ExperimentInfo {}
#[allow(clippy::too_many_arguments, deprecated)] // Deprecated to allow Serialize/Deserialize to touch Tensorflow variant
#[ark_api_macros::ark_bindgen(imports = "ark-ml-v1")]
mod ml {
use super::*;
/// Format of the snapshot data passed to `start_inference` (its `snapshot_format` argument).
///
/// NOTE(review): this doc previously read "Type used to describe a data object created using either
/// `create_data` or by the host", which does not match any function in this API — confirm and drop.
#[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
#[repr(u32)]
#[non_exhaustive]
pub enum SnapshotFormat {
/// ONNX format.
ONNX = 0,
/// Tensorflow format. Deprecated; see the deprecation note.
#[deprecated(note = "Tensorflow format is no longer supported")]
Tensorflow = 1,
/// Cervo format, as produced by `onnx_to_cervo`.
Cervo = 2,
}
/// Status of an experiment on the hive server, as reported in `ExperimentInfo::experiment_status`.
///
/// The discriminant `0` (`Pending`) must remain a valid value: the `Zeroable` impl for
/// `ExperimentInfo` relies on an all-zero bit pattern being a valid status.
///
/// NOTE(review): the exact meaning of the individual states is not documented here — confirm
/// against the hive server.
#[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
#[repr(u32)]
#[non_exhaustive]
pub enum ExperimentStatus {
Pending = 0,
Starting = 1,
Running = 2,
Terminating = 3,
Stopped = 4,
NotFound = 5,
Aborted = 6,
Unhealthy = 7,
}
#[derive(Copy, Clone, Default, Debug, NoUninit, CheckedBitPattern)]
#[repr(C)]
/// Indicates if a future has been completed.
/// Should be an enum, but kept as a struct for symmetry with PollVec
///
/// Returned by `poll_future_bool` and `poll_future_training_handle`.
pub struct PollSimple {
/// `true` if the future has been completed.
pub ready: bool,
}
#[derive(Copy, Clone, Default, Debug, NoUninit, CheckedBitPattern)]
#[repr(C)]
/// Indicates if a future that outputs a `Vec<T>` has been completed.
/// Should be an enum, but the ffi macros don't support that
///
/// Returned by `poll_future_vec_u8`, `poll_future_vec_metric` and `poll_future_vec_experiment_info`.
pub struct PollVec {
/// `true` if the future has been completed.
pub ready: bool,
/// Explicit padding between `ready` and `len`, so that the struct contains no implicit padding bytes.
pub _pad: [u8; 3],
/// The length of the output vector.
///
/// NOTE(review): presumably only meaningful when `ready` is `true` — confirm.
pub len: u32,
}
/// State of an agent's episode, attached to each observation pushed with the
/// `push_training_observation*` and `push_inference_observation*` functions.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
#[repr(u32)]
pub enum EpisodeState {
/// Episode is running.
Running = 0,
/// The episode has been terminated due to something the agent did.
///
/// The backend may use terminal masking for this state.
/// In that case the agent will behave as if from this state it cannot receive any more rewards.
///
/// This state is useful if the agent for example has died or in some other way ended up in a non-recoverable state.
///
/// The agent will receive a [`ark_api::Response::EndOfEpisode`](https://ark.embark.dev/api/ark_api/ml/enum.Response.html#variant.EndOfEpisode) when evaluating an observation with this state.
/// This agent is considered removed from the simulation and you may not send in an observation with the same id again.
Terminal = 1,
/// The episode has been terminated due to an external reason that the agent cannot control.
///
/// This is useful for things like timeouts that aren't necessarily due to bad behavior of the agent.
/// No terminal masking will be used in this case.
///
/// The agent will receive a [`ark_api::Response::EndOfEpisode`](https://ark.embark.dev/api/ark_api/ml/enum.Response.html#variant.EndOfEpisode) when evaluating an observation with this state.
/// This agent is considered removed from the simulation and you may not send in an observation with the same id again.
Interrupted = 2,
/// First frame of the episode.
///
/// The first observation that an agent sends to the engine must always be marked with this episode state.
Initial = 3,
}
extern "C" {
/// Convert a raw ONNX asset to Cervo format
///
/// Returns the serialized asset.
pub fn onnx_to_cervo(buffer: &[u8]) -> FFIResult<Vec<u8>>;
/// Checks whether we can connect to hive.
///
/// The promise outputs a `bool` (`true` if we can connect to hive).
/// Retrieve it with `poll_future_bool` and `take_future_bool`.
pub fn can_connect_to_hive(hive_url: &str, hive_port: u32) -> FFIResult<FutureHandle>;
/// Lists all experiments on the machine learning server.
///
/// The promise outputs a `Vec<ExperimentInfo>`.
/// Retrieve it with `poll_future_vec_experiment_info` and `take_future_vec_experiment_info`.
pub fn list_experiments(hive_url: &str, hive_port: u32) -> FFIResult<FutureHandle>;
/// Number of actions that an experiment outputs
pub fn experiment_action_count(context: TrainingHandle) -> FFIResult<u32>;
/// Number of features that an experiment takes as input
pub fn experiment_feature_count(context: TrainingHandle) -> FFIResult<u32>;
/// Starts training.
///
/// The promise outputs a `TrainingHandle`.
/// Retrieve it with `poll_future_training_handle` and `take_future_training_handle`.
pub fn start_training(
hive_url: &str,
hive_port: u32,
game_name: &str,
experiment_name: &str,
num_features: u32,
num_actions: u32,
) -> FFIResult<FutureHandle>;
/// Starts training from a checkpoint.
///
/// The promise outputs a `TrainingHandle`.
/// Retrieve it with `poll_future_training_handle` and `take_future_training_handle`.
pub fn start_training_from_checkpoint(
hive_url: &str,
hive_port: u32,
game_name: &str,
experiment_name: &str,
num_features: u32,
num_actions: u32,
checkpoint: &str,
) -> FFIResult<FutureHandle>;
/// Pushes an observation into a queue to be used later.
///
/// The `actions` parameter should only be set for augmented observations and demonstration observations.
///
/// See `submit_training_observations`, `submit_training_augmented_observations` and `submit_training_demonstration_observations`.
///
/// NOTE(review): the meaning of the `length` parameter is not documented here — confirm.
pub fn push_training_observation(
context: TrainingHandle,
id: u64,
episode_state: EpisodeState,
reward: f32,
features: &[f32],
actions: &[f32],
length: f32,
) -> FFIResult<()>;
/// Pushes an observation with metadata into a queue to be used later.
///
/// The `actions` parameter should only be set for augmented observations and demonstration observations.
///
/// See `submit_training_observations`, `submit_training_augmented_observations` and `submit_training_demonstration_observations`.
pub fn push_training_observation_with_metadata(
context: TrainingHandle,
id: u64,
episode_state: EpisodeState,
reward: f32,
features: &[f32],
actions: &[f32],
metadata: &[Metric],
) -> FFIResult<()>;
/// Submits observations previously pushed using `push_training_observation`.
/// `out_actions` should be a slice with length `num_actions * observation count`.
/// `out_values` should be a slice with length `observation count`.
///
/// Each consecutive sequence of `num_actions` values in `out_actions` corresponds to the actions for an observation.
/// If an observation was terminal or interrupted then all actions and the value for that observation will be zeroed out.
pub fn submit_training_observations(
context: TrainingHandle,
out_actions: &mut [f32],
out_values: &mut [f32],
) -> FFIResult<()>;
/// Submits observations previously pushed using `push_training_observation` as augmented observations.
pub fn submit_training_augmented_observations(context: TrainingHandle) -> FFIResult<()>;
/// Submits observations previously pushed using `push_training_observation` as demonstration observations.
pub fn submit_training_demonstration_observations(context: TrainingHandle)
-> FFIResult<()>;
/// Stops training with a given handle and destroys it.
///
/// This will only disconnect the current client.
/// If the experiment has remote workers they will continue to run.
pub fn stop_training(context: TrainingHandle) -> FFIResult<()>;
/// Retrieves a trained snapshot.
///
/// The promise outputs a `Vec<u8>`.
/// Retrieve it with `poll_future_vec_u8` and `take_future_vec_u8`.
pub fn download_snapshot(context: TrainingHandle) -> FFIResult<FutureHandle>;
/// Downloads metrics asynchronously
///
/// The promise outputs a `Vec<Metric>`.
/// Retrieve it with `poll_future_vec_metric` and `take_future_vec_metric`.
pub fn download_metrics(context: TrainingHandle) -> FFIResult<FutureHandle>;
/// Starts inference using `snapshot_data`, which is in the format given by `snapshot_format`.
///
/// Unlike `start_training`, the `InferenceHandle` is returned directly rather than through a future.
pub fn start_inference(
num_features: u32,
num_actions: u32,
snapshot_data: &[u8],
snapshot_format: SnapshotFormat,
) -> FFIResult<InferenceHandle>;
/// Submits observations previously pushed using `push_inference_observation`.
/// `out_actions` should be a slice with length `num_actions * observation count`.
/// `out_values` should be a slice with length `observation count`.
///
/// Each consecutive sequence of `num_actions` values in `out_actions` corresponds to the actions for an observation.
/// If an observation was terminal or interrupted then all actions and the value for that observation will be zeroed out.
pub fn submit_inference_observations(
context: InferenceHandle,
out_actions: &mut [f32],
out_values: &mut [f32],
) -> FFIResult<()>;
/// Pushes an observation into a queue to be used later.
///
/// See `submit_inference_observations`.
///
/// NOTE(review): the meaning of the `length` parameter is not documented here — confirm.
pub fn push_inference_observation(
context: InferenceHandle,
id: u64,
episode_state: EpisodeState,
reward: f32,
features: &[f32],
length: f32,
) -> FFIResult<()>;
/// Pushes an observation with metadata into a queue to be used later.
///
/// See `submit_inference_observations`.
pub fn push_inference_observation_with_metadata(
context: InferenceHandle,
id: u64,
episode_state: EpisodeState,
reward: f32,
features: &[f32],
metadata: &[Metric],
) -> FFIResult<()>;
/// Stops inference with a given handle.
///
/// NOTE(review): presumably also destroys the handle, mirroring `stop_training` — confirm.
pub fn stop_inference(context: InferenceHandle) -> FFIResult<()>;
/// Polls a future that outputs a boolean
pub fn poll_future_bool(handle: FutureHandle) -> FFIResult<PollSimple>;
/// Transfers the data of a completed future to the module side.
///
/// If this method returns an error the `data` may not have been assigned.
/// Calling this method consumes the handle and calling it again will return an error.
pub fn take_future_bool(handle: FutureHandle, data: &mut bool) -> FFIResult<()>;
/// Polls a future that outputs a `Vec<u8>`.
pub fn poll_future_vec_u8(handle: FutureHandle) -> FFIResult<PollVec>;
/// Transfers the `Vec<u8>` output of a completed future into `data`.
///
/// NOTE(review): presumably `data` must have the length reported by `poll_future_vec_u8`, and the
/// handle is consumed as with `take_future_bool` — confirm.
pub fn take_future_vec_u8(handle: FutureHandle, data: &mut [u8]) -> FFIResult<()>;
/// Polls a future that outputs a `Vec<Metric>`.
pub fn poll_future_vec_metric(handle: FutureHandle) -> FFIResult<PollVec>;
/// Transfers the `Vec<Metric>` output of a completed future into `data`.
///
/// NOTE(review): presumably `data` must have the length reported by `poll_future_vec_metric`, and the
/// handle is consumed as with `take_future_bool` — confirm.
pub fn take_future_vec_metric(handle: FutureHandle, data: &mut [Metric]) -> FFIResult<()>;
/// Polls a future that outputs a `TrainingHandle`.
pub fn poll_future_training_handle(handle: FutureHandle) -> FFIResult<PollSimple>;
/// Transfers the `TrainingHandle` output of a completed future into `data`.
///
/// NOTE(review): presumably consumes the handle as with `take_future_bool` — confirm.
pub fn take_future_training_handle(
handle: FutureHandle,
data: &mut TrainingHandle,
) -> FFIResult<()>;
/// Polls a future that outputs a `Vec<ExperimentInfo>`.
pub fn poll_future_vec_experiment_info(handle: FutureHandle) -> FFIResult<PollVec>;
/// Transfers the `Vec<ExperimentInfo>` output of a completed future into `data`.
///
/// NOTE(review): presumably `data` must have the length reported by
/// `poll_future_vec_experiment_info`, and the handle is consumed as with `take_future_bool` — confirm.
pub fn take_future_vec_experiment_info(
handle: FutureHandle,
data: &mut [ExperimentInfo],
) -> FFIResult<()>;
/// Drops a future handle.
///
/// This must be called if you do not call `take_*` to avoid a memory leak.
pub fn drop_future(handle: FutureHandle) -> FFIResult<()>;
}
}
pub use ml::*;