1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
// Declares the id and the name ("ml-v1") of this API.
// NOTE(review): `define_api_id!` is defined outside this file; presumably the hex value is a unique
// identifier used to tell APIs and their versions apart — confirm before changing either argument.
define_api_id!(0xaed6_0109_4f07_3c80, "ml-v1");

use crate::FFIResult;
use bytemuck::CheckedBitPattern;
use bytemuck::NoUninit;
use bytemuck::Pod;
use bytemuck::Zeroable;

/// A handle identifying a training context.
///
/// It is the output of the futures returned by `ml::start_training` and
/// `ml::start_training_from_checkpoint` (see `ml::take_future_training_handle`) and it is
/// destroyed by `ml::stop_training`.
pub type TrainingHandle = u64;
/// A handle identifying an inference context.
///
/// It is returned by `ml::start_inference` and passed to the `*_inference_*` functions and to
/// `ml::stop_inference`.
pub type InferenceHandle = u64;

/// A handle to a future on the engine-side.
///
/// The handle is untyped because the ffi proc macros do not support generics where we need them.
/// Therefore it is the API's job to keep track of what type the future will actually output and to
/// call the corresponding `poll_future_*` and `take_future_*` functions.
/// See e.g. `ml::take_future_bool`.
///
/// A handle that is never passed to a `take_future_*` function must be released with
/// `ml::drop_future` to avoid a memory leak.
pub type FutureHandle = u64;

/// A variable length string, stored inline in a fixed size buffer, that can be up to `N` bytes long.
///
/// NOTE(review): the text encoding is not stated in this file — presumably UTF-8; confirm against
/// the host implementation.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct FixedBufferString<const N: usize> {
    /// Length of the string in bytes.
    ///
    /// NOTE(review): nothing in this type enforces `length <= N`; readers should validate it
    /// before slicing `bytes`.
    pub length: u32,
    /// Contents of the string.
    /// Only the first `length` bytes are part of the string.
    /// The remaining bytes are unspecified and must be ignored.
    pub bytes: [u8; N],
}

// SAFETY: All fields are Pod. With `#[repr(C)]` the `bytes` array (alignment 1) directly follows the
// 4-byte `length`, so there is no interior padding, and there is no trailing padding as long as `N`
// is a multiple of 4 (true for the 128 and 64 byte instantiations used by `ExperimentInfo`).
// NOTE(review): for an `N` that is not a multiple of 4 the struct would get trailing padding and
// this impl would no longer be sound — only instantiate this type with a multiple of 4.
// Can't be derived due to limitations of derive macro with generics.
unsafe impl<const N: usize> Pod for FixedBufferString<N> {}

// SAFETY: Both fields (`u32` and `[u8; N]`) are Zeroable for every `N`, so the all-zero bit pattern
// is a valid value; it has `length == 0`, i.e. it is the empty string.
// Can't be derived due to limitations of derive macro with generics.
unsafe impl<const N: usize> Zeroable for FixedBufferString<N> {}

/// A named `f32` value.
///
/// Used as metadata attached to observations (`ml::push_training_observation_with_metadata`,
/// `ml::push_inference_observation_with_metadata`) and as the element type of the output of
/// `ml::download_metrics`.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub struct Metric {
    /// Name of the metric, stored inline in a fixed 28 byte buffer.
    ///
    /// NOTE(review): the encoding, and how names shorter than 28 bytes are padded, is not stated
    /// in this file — presumably UTF-8 padded with zero bytes; confirm against the host.
    pub name: [u8; 28],
    /// Value of the metric.
    pub value: f32,
}

/// Description of a single experiment; the future returned by `ml::list_experiments` outputs a
/// list of these.
///
/// Note: needs to be Copy due to wasmtime constraints.
#[repr(C)]
#[derive(Clone, Copy, Debug, NoUninit, CheckedBitPattern)]
pub struct ExperimentInfo {
    /// Human readable name of the experiment.
    ///
    /// Experiment names can be long, hence the larger buffer compared to `id`.
    pub name: FixedBufferString<128>,
    /// Internal id of the experiment.
    pub id: FixedBufferString<64>,
    /// Time when the experiment was started, as nanoseconds since the UNIX epoch.
    pub started_at: i64,
    /// Time when the experiment was ended, as nanoseconds since the UNIX epoch, or 0 if it has not ended yet.
    pub ended_at: i64,
    /// The experiment will be automatically stopped after this number of seconds.
    pub max_duration: u64,
    /// Status of the experiment on the hive server.
    pub experiment_status: ExperimentStatus,
    /// Number of game workers that are started on the server. May be 0 for in-client training.
    pub worker_count: u32,
}

// SAFETY: This is *technically* not safe yet with how bytemuck currently defines Zeroable, but since
// we're only using Zeroable to get `Zeroable::zeroed` (which *is* safe), we're fine. Zeroed memory is
// fine here because `ExperimentStatus` has 0 as a valid value (`ExperimentStatus::Pending`) and all
// other fields are integers or `FixedBufferString`s, for which zeroed memory is valid as well.
unsafe impl Zeroable for ExperimentInfo {}

#[allow(clippy::too_many_arguments, deprecated)] // Deprecated to allow Serialize/Deserialize to touch Tensorflow variant
#[ark_api_macros::ark_bindgen(imports = "ark-ml-v1")]
mod ml {
    use super::*;

    /// Format of the snapshot data that is passed to `start_inference`.
    #[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
    #[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
    #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
    #[repr(u32)]
    #[non_exhaustive]
    pub enum SnapshotFormat {
        /// The snapshot is in ONNX format.
        ONNX = 0,

        /// The snapshot is in Tensorflow format. Deprecated: this format is no longer supported.
        #[deprecated(note = "Tensorflow format is no longer supported")]
        Tensorflow = 1,

        /// The snapshot is in Cervo format, e.g. as produced by `onnx_to_cervo`.
        Cervo = 2,
    }

    /// Status of an experiment on the hive server, see `ExperimentInfo::experiment_status`.
    ///
    /// `Pending` must keep the value 0: the `Zeroable` impl of `ExperimentInfo` relies on zeroed
    /// memory being a valid status.
    ///
    /// NOTE(review): the exact meaning of each variant is defined by the hive server and is not
    /// stated in this file; presumably it follows from the variant names — confirm if it matters.
    #[cfg_attr(feature = "with_serde", derive(serde::Serialize, serde::Deserialize))]
    #[cfg_attr(feature = "with_speedy", derive(speedy::Writable, speedy::Readable))]
    #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
    #[repr(u32)]
    #[non_exhaustive]
    pub enum ExperimentStatus {
        Pending = 0,
        Starting = 1,
        Running = 2,
        Terminating = 3,
        Stopped = 4,
        NotFound = 5,
        Aborted = 6,
        Unhealthy = 7,
    }

    #[derive(Copy, Clone, Default, Debug, NoUninit, CheckedBitPattern)]
    #[repr(C)]
    /// Indicates if a future has been completed.
    ///
    /// Returned by `poll_future_bool` and `poll_future_training_handle`.
    /// Should be an enum, but kept as a struct for symmetry with `PollVec`.
    pub struct PollSimple {
        /// `true` if the future has been completed.
        pub ready: bool,
    }

    #[derive(Copy, Clone, Default, Debug, NoUninit, CheckedBitPattern)]
    #[repr(C)]
    /// Indicates if a future that outputs a `Vec<T>` has been completed.
    ///
    /// Returned by the `poll_future_vec_*` functions.
    /// Should be an enum, but the ffi macros don't support that.
    pub struct PollVec {
        /// `true` if the future has been completed.
        pub ready: bool,
        /// Explicit padding between the 1 byte `ready` and the 4 byte aligned `len`, so that the
        /// struct has no implicit padding bytes.
        pub _pad: [u8; 3],
        /// The length of the output vector.
        ///
        /// NOTE(review): presumably only meaningful once `ready` is `true`, and presumably the
        /// length that the slice passed to the matching `take_future_vec_*` must have — confirm.
        pub len: u32,
    }

    /// State of an agent's episode, sent along with every observation that is pushed using the
    /// `push_training_observation*` and `push_inference_observation*` functions.
    #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
    #[repr(u32)]
    pub enum EpisodeState {
        /// Episode is running.
        Running = 0,
        /// The episode has been terminated due to something the agent did.
        ///
        /// The backend may use terminal masking for this state.
        /// In that case the agent will behave as if from this state it cannot receive any more rewards.
        ///
        /// This state is useful if the agent for example has died or in some other way ended up in a non-recoverable state.
        ///
        /// The agent will receive a [`ark_api::Response::EndOfEpisode`](https://ark.embark.dev/api/ark_api/ml/enum.Response.html#variant.EndOfEpisode) when evaluating an observation with this state.
        /// This agent is considered removed from the simulation and you may not send in an observation with the same id again.
        Terminal = 1,
        /// The episode has been terminated due to an external reason that the agent cannot control.
        ///
        /// This is useful for things like timeouts that aren't necessarily due to bad behavior of the agent.
        /// No terminal masking will be used in this case.
        ///
        /// The agent will receive a [`ark_api::Response::EndOfEpisode`](https://ark.embark.dev/api/ark_api/ml/enum.Response.html#variant.EndOfEpisode) when evaluating an observation with this state.
        /// This agent is considered removed from the simulation and you may not send in an observation with the same id again.
        Interrupted = 2,
        /// First frame of the episode.
        ///
        /// The first observation that an agent sends to the engine must always be marked with this episode state.
        Initial = 3,
    }

    extern "C" {
        /// Convert a raw ONNX asset to Cervo format
        ///
        /// `buffer` holds the bytes of the raw ONNX asset.
        /// Returns the serialized asset.
        pub fn onnx_to_cervo(buffer: &[u8]) -> FFIResult<Vec<u8>>;

        /// Checks if we can connect to hive.
        ///
        /// The promise outputs a `bool`, which is `true` if we can connect to hive.
        /// Use `poll_future_bool` and `take_future_bool` on the returned handle.
        pub fn can_connect_to_hive(hive_url: &str, hive_port: u32) -> FFIResult<FutureHandle>;

        /// Lists all experiments on the machine learning server.
        ///
        /// The promise outputs a `Vec<ExperimentInfo>`.
        /// Use `poll_future_vec_experiment_info` and `take_future_vec_experiment_info` on the returned handle.
        pub fn list_experiments(hive_url: &str, hive_port: u32) -> FFIResult<FutureHandle>;

        /// Number of actions that an experiment outputs
        pub fn experiment_action_count(context: TrainingHandle) -> FFIResult<u32>;

        /// Number of features that an experiment takes as input
        pub fn experiment_feature_count(context: TrainingHandle) -> FFIResult<u32>;

        /// Starts training.
        ///
        /// The promise outputs a `TrainingHandle`.
        /// Use `poll_future_training_handle` and `take_future_training_handle` on the returned handle.
        pub fn start_training(
            hive_url: &str,
            hive_port: u32,
            game_name: &str,
            experiment_name: &str,
            num_features: u32,
            num_actions: u32,
        ) -> FFIResult<FutureHandle>;

        /// Starts training from a checkpoint.
        ///
        /// The promise outputs a `TrainingHandle`.
        /// Use `poll_future_training_handle` and `take_future_training_handle` on the returned handle.
        pub fn start_training_from_checkpoint(
            hive_url: &str,
            hive_port: u32,
            game_name: &str,
            experiment_name: &str,
            num_features: u32,
            num_actions: u32,
            checkpoint: &str,
        ) -> FFIResult<FutureHandle>;

        /// Pushes an observation into a queue to be used later.
        ///
        /// The `actions` parameter should only be set for augmented observations and demonstration observations.
        ///
        /// NOTE(review): the meaning of the `length` parameter is not documented anywhere in this file — TODO confirm.
        ///
        /// See `submit_training_observations`, `submit_training_augmented_observations` and `submit_training_demonstration_observations`.
        pub fn push_training_observation(
            context: TrainingHandle,
            id: u64,
            episode_state: EpisodeState,
            reward: f32,
            features: &[f32],
            actions: &[f32],
            length: f32,
        ) -> FFIResult<()>;

        /// Pushes an observation with metadata into a queue to be used later.
        ///
        /// The `actions` parameter should only be set for augmented observations and demonstration observations.
        ///
        /// See `submit_training_observations`, `submit_training_augmented_observations` and `submit_training_demonstration_observations`.
        pub fn push_training_observation_with_metadata(
            context: TrainingHandle,
            id: u64,
            episode_state: EpisodeState,
            reward: f32,
            features: &[f32],
            actions: &[f32],
            metadata: &[Metric],
        ) -> FFIResult<()>;

        /// Submits observations previously pushed using `push_training_observation`.
        /// `out_actions` should be a slice with length `num_actions * observation count`.
        /// `out_values` should be a slice with length `observation count`.
        ///
        /// Each consecutive sequence of `num_actions` values in `out_actions` corresponds to the actions for an observation.
        /// If an observation was terminal or interrupted then all actions and the value for that observation will be zeroed out.
        pub fn submit_training_observations(
            context: TrainingHandle,
            out_actions: &mut [f32],
            out_values: &mut [f32],
        ) -> FFIResult<()>;

        /// Submits observations previously pushed using `push_training_observation` as augmented observations.
        pub fn submit_training_augmented_observations(context: TrainingHandle) -> FFIResult<()>;

        /// Submits observations previously pushed using `push_training_observation` as demonstration observations.
        pub fn submit_training_demonstration_observations(context: TrainingHandle)
            -> FFIResult<()>;

        /// Stops training with a given handle and destroys it.
        ///
        /// This will only disconnect the current client.
        /// If the experiment has remote workers they will continue to run.
        pub fn stop_training(context: TrainingHandle) -> FFIResult<()>;

        /// Retrieves a trained snapshot.
        ///
        /// The promise outputs a `Vec<u8>`.
        /// Use `poll_future_vec_u8` and `take_future_vec_u8` on the returned handle.
        pub fn download_snapshot(context: TrainingHandle) -> FFIResult<FutureHandle>;

        /// Downloads metrics asynchronously
        ///
        /// The promise outputs a `Vec<Metric>`.
        /// Use `poll_future_vec_metric` and `take_future_vec_metric` on the returned handle.
        pub fn download_metrics(context: TrainingHandle) -> FFIResult<FutureHandle>;

        /// Starts inference using the snapshot in `snapshot_data`, which is in the format given by `snapshot_format`.
        ///
        /// Returns a handle to be used with the `*_inference_*` functions and `stop_inference`.
        ///
        /// NOTE(review): presumably `num_features` and `num_actions` have to match the inputs and outputs of the snapshot — confirm.
        pub fn start_inference(
            num_features: u32,
            num_actions: u32,
            snapshot_data: &[u8],
            snapshot_format: SnapshotFormat,
        ) -> FFIResult<InferenceHandle>;

        /// Submits observations previously pushed using `push_inference_observation`.
        /// `out_actions` should be a slice with length `num_actions * observation count`.
        /// `out_values` should be a slice with length `observation count`.
        ///
        /// Each consecutive sequence of `num_actions` values in `out_actions` corresponds to the actions for an observation.
        /// If an observation was terminal or interrupted then all actions and the value for that observation will be zeroed out.
        pub fn submit_inference_observations(
            context: InferenceHandle,
            out_actions: &mut [f32],
            out_values: &mut [f32],
        ) -> FFIResult<()>;

        /// Pushes an observation into a queue to be used later.
        ///
        /// NOTE(review): the meaning of the `length` parameter is not documented anywhere in this file — TODO confirm.
        ///
        /// See `submit_inference_observations`.
        pub fn push_inference_observation(
            context: InferenceHandle,
            id: u64,
            episode_state: EpisodeState,
            reward: f32,
            features: &[f32],
            length: f32,
        ) -> FFIResult<()>;

        /// Pushes an observation with metadata into a queue to be used later.
        ///
        /// See `submit_inference_observations`.
        pub fn push_inference_observation_with_metadata(
            context: InferenceHandle,
            id: u64,
            episode_state: EpisodeState,
            reward: f32,
            features: &[f32],
            metadata: &[Metric],
        ) -> FFIResult<()>;

        /// Stops inference with a given handle.
        ///
        /// NOTE(review): presumably this also destroys the handle, like `stop_training` does — confirm.
        pub fn stop_inference(context: InferenceHandle) -> FFIResult<()>;

        /// Polls a future that outputs a boolean
        pub fn poll_future_bool(handle: FutureHandle) -> FFIResult<PollSimple>;

        /// Transfers the data of a completed future to the module side.
        ///
        /// If this method returns an error the `data` may not have been assigned.
        /// Calling this method consumes the handle and calling it again will return an error.
        pub fn take_future_bool(handle: FutureHandle, data: &mut bool) -> FFIResult<()>;

        /// Polls a future that outputs a `Vec<u8>`
        pub fn poll_future_vec_u8(handle: FutureHandle) -> FFIResult<PollVec>;
        /// Transfers the `Vec<u8>` output of a completed future into `data`.
        ///
        /// NOTE(review): presumably `data` must have the length reported by `poll_future_vec_u8`, and presumably
        /// the handle is consumed like in `take_future_bool` — confirm.
        pub fn take_future_vec_u8(handle: FutureHandle, data: &mut [u8]) -> FFIResult<()>;

        /// Polls a future that outputs a `Vec<Metric>`
        pub fn poll_future_vec_metric(handle: FutureHandle) -> FFIResult<PollVec>;
        /// Transfers the `Vec<Metric>` output of a completed future into `data`.
        ///
        /// NOTE(review): presumably `data` must have the length reported by `poll_future_vec_metric`, and presumably
        /// the handle is consumed like in `take_future_bool` — confirm.
        pub fn take_future_vec_metric(handle: FutureHandle, data: &mut [Metric]) -> FFIResult<()>;

        /// Polls a future that outputs a `TrainingHandle`
        pub fn poll_future_training_handle(handle: FutureHandle) -> FFIResult<PollSimple>;
        /// Transfers the `TrainingHandle` output of a completed future into `data`.
        ///
        /// NOTE(review): presumably the handle is consumed like in `take_future_bool` — confirm.
        pub fn take_future_training_handle(
            handle: FutureHandle,
            data: &mut TrainingHandle,
        ) -> FFIResult<()>;

        /// Polls a future that outputs a `Vec<ExperimentInfo>`
        pub fn poll_future_vec_experiment_info(handle: FutureHandle) -> FFIResult<PollVec>;
        /// Transfers the `Vec<ExperimentInfo>` output of a completed future into `data`.
        ///
        /// NOTE(review): presumably `data` must have the length reported by `poll_future_vec_experiment_info`, and
        /// presumably the handle is consumed like in `take_future_bool` — confirm.
        pub fn take_future_vec_experiment_info(
            handle: FutureHandle,
            data: &mut [ExperimentInfo],
        ) -> FFIResult<()>;

        /// Drops a future handle.
        ///
        /// This must be called if you do not call `take_*` to avoid a memory leak.
        pub fn drop_future(handle: FutureHandle) -> FFIResult<()>;
    }
}

pub use ml::*;