// Declares the identity of this binding set: a fixed 64-bit id plus the name
// "ml-v3". NOTE(review): how these values are consumed (e.g. host-side
// compatibility checks) is defined by `define_api_id!`, which is not visible here.
define_api_id!(0xaed6_0109_4f07_3c82, "ml-v3");
2
3pub use super::ml_v1::FixedBufferString;
4pub use super::ml_v1::FutureHandle;
5pub use super::ml_v1::InferenceHandle;
6pub use super::ml_v1::PollVec;
7pub use super::ml_v1::TrainingHandle;
8use crate::FFIResult;
9use bytemuck::Pod;
10use bytemuck::Zeroable;
11
12#[repr(C)]
13#[derive(Copy, Clone, Debug, Pod, Zeroable)]
14pub struct Metric {
15 pub name: FixedBufferString<64>,
16 pub value: f32,
17}
18
// `start_training` below takes ten parameters. The lint is silenced because
// these declarations describe the imported FFI functions; regrouping the
// arguments would change the imported signature.
#[allow(clippy::too_many_arguments)]
// NOTE(review): `ark_bindgen` presumably generates the Rust-side bindings for the
// "ark-ml-v3" import module from the `extern "C"` declarations below; the exact
// generated API (wrappers, error mapping) is defined by the macro, not here.
#[ark_api_macros::ark_bindgen(imports = "ark-ml-v3")]
mod ml {
    use super::*;

    extern "C" {
        /// Starts a training session and returns a handle to its pending result.
        ///
        /// - `hive_url` / `hive_port`: presumably the address of the hive server to
        ///   connect to — confirm.
        /// - `game_name` / `experiment_name`: names identifying this run.
        /// - `num_features` / `num_actions`: presumably the length of every
        ///   `features` / `actions` slice later passed to
        ///   `push_training_observation` — confirm.
        /// - `num_remote_workers`: number of remote workers to request.
        /// - `config`: configuration string. NOTE(review): its format is not
        ///   visible here.
        /// - `checkpoint`: presumably a checkpoint to start from. NOTE(review):
        ///   the meaning of an empty string is not visible here.
        /// - `training_duration_in_seconds`: requested training duration, in seconds.
        ///
        /// NOTE(review): the returned `FutureHandle` presumably resolves to a
        /// `TrainingHandle`; the matching poll/take functions are not declared
        /// in this module (possibly in `ml_v1`) — confirm.
        pub fn start_training(
            hive_url: &str,
            hive_port: u32,
            game_name: &str,
            experiment_name: &str,
            num_features: u32,
            num_actions: u32,
            num_remote_workers: u32,
            config: &str,
            checkpoint: &str,
            training_duration_in_seconds: u64,
        ) -> FFIResult<FutureHandle>;

        /// Pushes one observation into the training session `context`.
        ///
        /// - `id`: caller-supplied identifier. NOTE(review): presumably
        ///   identifies the agent/episode producing the observation — confirm.
        /// - `episode_state`: raw `u32` code. NOTE(review): the encoding of
        ///   its values is not defined in this module.
        /// - `reward`: reward value for this observation.
        /// - `features` / `actions`: presumably `num_features` / `num_actions`
        ///   values as passed to `start_training` — confirm.
        /// - `metadata`: additional named values attached to the observation.
        pub fn push_training_observation(
            context: TrainingHandle,
            id: u64,
            episode_state: u32,
            reward: f32,
            features: &[f32],
            actions: &[f32],
            metadata: &[Metric],
        ) -> FFIResult<()>;

        /// Pushes one observation into the inference session `context`.
        ///
        /// Same parameters as `push_training_observation`, except that no
        /// `actions` slice is passed.
        pub fn push_inference_observation(
            context: InferenceHandle,
            id: u64,
            episode_state: u32,
            reward: f32,
            features: &[f32],
            metadata: &[Metric],
        ) -> FFIResult<()>;

        /// Requests the metrics of the training session `context`.
        ///
        /// NOTE(review): the returned `FutureHandle` is presumably resolved with
        /// `poll_future_vec_metric` / `take_future_vec_metric` below — confirm.
        pub fn download_metrics(context: TrainingHandle) -> FFIResult<FutureHandle>;

        /// Polls a future whose result is a list of `Metric`s.
        ///
        /// NOTE(review): `PollVec` (from `ml_v1`) presumably reports whether the
        /// result is ready and how many elements it has — confirm there.
        pub fn poll_future_vec_metric(handle: FutureHandle) -> FFIResult<PollVec>;
        /// Writes the `Metric` results of `handle` into the caller-provided `data`.
        ///
        /// NOTE(review): presumably `data.len()` must match the element count
        /// reported by `poll_future_vec_metric` — confirm.
        pub fn take_future_vec_metric(handle: FutureHandle, data: &mut [Metric]) -> FFIResult<()>;
    }
}
79
80pub use ml::*;