// ark_api_ffi/ffi/ml_v3.rs

1define_api_id!(0xaed6_0109_4f07_3c82, "ml-v3");
2
3pub use super::ml_v1::FixedBufferString;
4pub use super::ml_v1::FutureHandle;
5pub use super::ml_v1::InferenceHandle;
6pub use super::ml_v1::PollVec;
7pub use super::ml_v1::TrainingHandle;
8use crate::FFIResult;
9use bytemuck::Pod;
10use bytemuck::Zeroable;
11
12#[repr(C)]
13#[derive(Copy, Clone, Debug, Pod, Zeroable)]
14pub struct Metric {
15    pub name: FixedBufferString<64>,
16    pub value: f32,
17}
18
19#[allow(clippy::too_many_arguments)]
20#[ark_api_macros::ark_bindgen(imports = "ark-ml-v3")]
21mod ml {
22    use super::*;
23
24    extern "C" {
25        /// Starts training.
26        ///
27        /// The promise outputs a `TrainingHandle`.
28        ///
29        /// Will start from a checkpoint if the checkpoint parameter is non-empty.
30        pub fn start_training(
31            hive_url: &str,
32            hive_port: u32,
33            game_name: &str,
34            experiment_name: &str,
35            num_features: u32,
36            num_actions: u32,
37            num_remote_workers: u32,
38            config: &str,
39            checkpoint: &str,
40            training_duration_in_seconds: u64,
41        ) -> FFIResult<FutureHandle>;
42
43        /// Pushes an observation with metadata into a queue to be used later.
44        ///
45        /// The `actions` parameter should only be set for augmented observations and demonstration observations.
46        ///
47        /// See `submit_training_observations`, `submit_training_augmented_observations` and `submit_training_demonstration_observations`.
48        pub fn push_training_observation(
49            context: TrainingHandle,
50            id: u64,
51            // Actually a ml_v1::EpisodeState, but due to ffi limitations we cannot pass that here
52            episode_state: u32,
53            reward: f32,
54            features: &[f32],
55            actions: &[f32],
56            metadata: &[Metric],
57        ) -> FFIResult<()>;
58
59        /// Pushes an observation with metadata into a queue to be used later.
60        pub fn push_inference_observation(
61            context: InferenceHandle,
62            id: u64,
63            // Actually a ml_v1::EpisodeState, but due to ffi limitations we cannot pass that here
64            episode_state: u32,
65            reward: f32,
66            features: &[f32],
67            metadata: &[Metric],
68        ) -> FFIResult<()>;
69
70        /// Downloads metrics asynchronously
71        ///
72        /// The promise outputs a `Vec<Metric>`.
73        pub fn download_metrics(context: TrainingHandle) -> FFIResult<FutureHandle>;
74
75        pub fn poll_future_vec_metric(handle: FutureHandle) -> FFIResult<PollVec>;
76        pub fn take_future_vec_metric(handle: FutureHandle, data: &mut [Metric]) -> FFIResult<()>;
77    }
78}
79
80pub use ml::*;