1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
// Declares this API version's identity: a 64-bit id and the name "ml-v3".
// NOTE(review): how the id/name pair is consumed is defined by the
// `define_api_id!` macro, which is not visible here — confirm there.
define_api_id!(0xaed6_0109_4f07_3c82, "ml-v3");
pub use super::ml_v1::FixedBufferString;
pub use super::ml_v1::FutureHandle;
pub use super::ml_v1::InferenceHandle;
pub use super::ml_v1::PollVec;
pub use super::ml_v1::TrainingHandle;
use crate::FFIResult;
use bytemuck::Pod;
use bytemuck::Zeroable;
/// A named scalar value passed across the `ml` FFI boundary.
///
/// Used as per-observation metadata (`push_training_observation`,
/// `push_inference_observation`) and as the element type of the metrics
/// returned by `download_metrics` / `take_future_vec_metric`.
///
/// `#[repr(C)]` together with the `Pod`/`Zeroable` derives gives the struct a
/// fixed, padding-free layout (the bytemuck `Pod` derive rejects padding), so
/// it can be copied across the boundary as plain bytes. Do not reorder fields.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub struct Metric {
/// Name of the metric, stored inline in a `FixedBufferString<64>`.
// NOTE(review): presumably at most 64 bytes of text — confirm against
// `ml_v1::FixedBufferString`.
pub name: FixedBufferString<64>,
/// Value of the metric.
pub value: f32,
}
// `start_training` takes ten parameters, more than clippy's default
// `too_many_arguments` limit (7). The extern declarations below are the
// imported `ark-ml-v3` host interface, so they are presumably kept as flat
// argument lists rather than bundled into a struct.
#[allow(clippy::too_many_arguments)]
// NOTE(review): `ark_bindgen` presumably generates the callable bindings for
// these declarations (they are re-exported via `pub use ml::*`); see
// `ark_api_macros` for what exactly it emits.
#[ark_api_macros::ark_bindgen(imports = "ark-ml-v3")]
mod ml {
// Brings `Metric`, `FFIResult` and the re-exported `ml_v1` handle types into scope.
use super::*;
extern "C" {
/// Starts training.
///
/// The promise outputs a `TrainingHandle`.
///
/// Will start from a checkpoint if the checkpoint parameter is non-empty.
pub fn start_training(
hive_url: &str,
hive_port: u32,
game_name: &str,
experiment_name: &str,
num_features: u32,
num_actions: u32,
num_remote_workers: u32,
config: &str,
checkpoint: &str,
training_duration_in_seconds: u64,
) -> FFIResult<FutureHandle>;
/// Pushes an observation with metadata into a queue to be used later.
///
/// The `actions` parameter should only be non-empty for augmented observations
/// and demonstration observations (presumably pass an empty slice otherwise).
///
/// See `submit_training_observations`, `submit_training_augmented_observations` and `submit_training_demonstration_observations`.
/// NOTE(review): those functions are not declared in this module; presumably
/// they come from an earlier ml API version — confirm.
pub fn push_training_observation(
context: TrainingHandle,
id: u64,
// Actually an `ml_v1::EpisodeState`, passed as a `u32` because that type
// cannot be passed through this FFI declaration.
episode_state: u32,
reward: f32,
features: &[f32],
actions: &[f32],
metadata: &[Metric],
) -> FFIResult<()>;
/// Pushes an observation with metadata into a queue to be used later.
///
/// Inference counterpart of `push_training_observation`; it takes an
/// `InferenceHandle` and has no `actions` parameter.
pub fn push_inference_observation(
context: InferenceHandle,
id: u64,
// Actually an `ml_v1::EpisodeState`, passed as a `u32` because that type
// cannot be passed through this FFI declaration.
episode_state: u32,
reward: f32,
features: &[f32],
metadata: &[Metric],
) -> FFIResult<()>;
/// Downloads metrics asynchronously.
///
/// The promise outputs a `Vec<Metric>`; presumably it is retrieved with
/// `poll_future_vec_metric` followed by `take_future_vec_metric`.
pub fn download_metrics(context: TrainingHandle) -> FFIResult<FutureHandle>;
/// Polls a future whose promise outputs a `Vec<Metric>` (e.g. the handle
/// returned by `download_metrics`).
///
/// NOTE(review): the meaning of the returned `PollVec` (readiness and/or
/// element count) is defined in `ml_v1` — confirm there.
pub fn poll_future_vec_metric(handle: FutureHandle) -> FFIResult<PollVec>;
/// Takes the `Vec<Metric>` output of a future, writing it into `data`.
///
/// NOTE(review): presumably only valid once polling reports the result is
/// ready, with `data` sized to the length reported by
/// `poll_future_vec_metric` — confirm against the host implementation.
pub fn take_future_vec_metric(handle: FutureHandle, data: &mut [Metric]) -> FFIResult<()>;
}
}
pub use ml::*;