// vyre_driver/device_convergence.rs
//
// Plans device-side convergence loops in which the host reads back only the
// final changed flag instead of polling convergence state every iteration.

/// How the host reads convergence state back from the device.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConvergenceReadbackPolicy {
    /// The host reads only the final changed flag; there are no per-iteration
    /// host polls.
    FinalFlagOnly,
}

10#[derive(Clone, Copy, Debug, Eq, PartialEq)]
12pub struct DeviceConvergencePlan {
13 pub max_device_iterations: u32,
15 pub host_sync_points: u32,
17 pub changed_flag_readback_bytes: u32,
19 pub host_iteration_polls: u32,
21 pub readback_policy: ConvergenceReadbackPolicy,
23}
24
/// Reasons `plan_device_convergence` can reject a convergence request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DeviceConvergencePlanError {
    /// The requested device iteration budget was zero.
    EmptyIterationBudget,
    /// The changed flag width was not 4 bytes (one device `u32`).
    InvalidChangedFlagWidth {
        /// The rejected width, in bytes.
        bytes: u32,
    },
    /// A nonzero number of host iteration polls was requested.
    HostPolledConvergence {
        /// The number of host iteration polls that was requested.
        polls: u32,
    },
}

42impl std::fmt::Display for DeviceConvergencePlanError {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 match self {
45 Self::EmptyIterationBudget => f.write_str(
46 "device convergence iteration budget is zero. Fix: use at least one device iteration.",
47 ),
48 Self::InvalidChangedFlagWidth { bytes } => write!(
49 f,
50 "device convergence changed-flag width is {bytes} bytes. Fix: use a 4-byte device u32 changed flag."
51 ),
52 Self::HostPolledConvergence { polls } => write!(
53 f,
54 "device convergence requested {polls} host iteration polls. Fix: keep convergence detection device-side and read only the final changed flag."
55 ),
56 }
57 }
58}
59
60impl std::error::Error for DeviceConvergencePlanError {}
61
62pub fn plan_device_convergence(
70 max_device_iterations: u32,
71 changed_flag_bytes: u32,
72 requested_host_iteration_polls: u32,
73) -> Result<DeviceConvergencePlan, DeviceConvergencePlanError> {
74 if max_device_iterations == 0 {
75 return Err(DeviceConvergencePlanError::EmptyIterationBudget);
76 }
77 if changed_flag_bytes != 4 {
78 return Err(DeviceConvergencePlanError::InvalidChangedFlagWidth {
79 bytes: changed_flag_bytes,
80 });
81 }
82 if requested_host_iteration_polls != 0 {
83 return Err(DeviceConvergencePlanError::HostPolledConvergence {
84 polls: requested_host_iteration_polls,
85 });
86 }
87
88 Ok(DeviceConvergencePlan {
89 max_device_iterations,
90 host_sync_points: 1,
91 changed_flag_readback_bytes: changed_flag_bytes,
92 host_iteration_polls: 0,
93 readback_policy: ConvergenceReadbackPolicy::FinalFlagOnly,
94 })
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn convergence_plan_reads_final_flag_once() {
        let plan = plan_device_convergence(128, 4, 0).expect("Fix: valid plan should build");

        assert_eq!(plan.max_device_iterations, 128);
        assert_eq!(plan.host_sync_points, 1);
        assert_eq!(plan.changed_flag_readback_bytes, 4);
        assert_eq!(plan.host_iteration_polls, 0);
        assert_eq!(
            plan.readback_policy,
            ConvergenceReadbackPolicy::FinalFlagOnly
        );
    }

    #[test]
    fn convergence_plan_rejects_empty_iteration_budget() {
        let err = plan_device_convergence(0, 4, 0).expect_err("zero iterations cannot converge");

        assert_eq!(err, DeviceConvergencePlanError::EmptyIterationBudget);
        assert!(err.to_string().contains("at least one device iteration"));
    }

    #[test]
    fn convergence_plan_rejects_wrong_changed_flag_width() {
        let err = plan_device_convergence(8, 1, 0).expect_err("changed flag must be a u32");

        assert_eq!(
            err,
            DeviceConvergencePlanError::InvalidChangedFlagWidth { bytes: 1 }
        );
        assert!(err.to_string().contains("4-byte device u32 changed flag"));

        // A u64-sized flag is rejected just like a byte-sized one.
        let wide = plan_device_convergence(8, 8, 0).expect_err("changed flag must be a u32");
        assert_eq!(
            wide,
            DeviceConvergencePlanError::InvalidChangedFlagWidth { bytes: 8 }
        );
    }

    #[test]
    fn convergence_plan_rejects_host_polled_iterations() {
        let err = plan_device_convergence(8, 4, 8)
            .expect_err("host polling every iteration is forbidden");

        assert_eq!(
            err,
            DeviceConvergencePlanError::HostPolledConvergence { polls: 8 }
        );
        assert!(err.to_string().contains("read only the final changed flag"));
    }

    #[test]
    fn convergence_plan_accepts_maximum_iteration_budget() {
        let plan = plan_device_convergence(u32::MAX, 4, 0)
            .expect("Fix: the largest u32 iteration budget should plan");

        assert_eq!(plan.max_device_iterations, u32::MAX);
        assert_eq!(plan.host_sync_points, 1);
        assert_eq!(plan.host_iteration_polls, 0);
    }

    #[test]
    fn convergence_plan_reports_first_failing_check() {
        // Validation order: iteration budget, then flag width, then host polls.
        assert_eq!(
            plan_device_convergence(0, 1, 8),
            Err(DeviceConvergencePlanError::EmptyIterationBudget)
        );
        assert_eq!(
            plan_device_convergence(8, 1, 8),
            Err(DeviceConvergencePlanError::InvalidChangedFlagWidth { bytes: 1 })
        );
    }

    #[test]
    fn convergence_plan_error_is_a_std_error() {
        let err: Box<dyn std::error::Error> =
            Box::new(DeviceConvergencePlanError::EmptyIterationBudget);

        assert!(err.source().is_none());
        assert!(err.to_string().contains("at least one device iteration"));
    }

    #[test]
    fn generated_convergence_iteration_budgets_preserve_final_only_contract() {
        for max_device_iterations in 1..=4_096 {
            let plan = plan_device_convergence(max_device_iterations, 4, 0)
                .expect("Fix: generated nonzero iteration budgets should plan");
            assert_eq!(plan.max_device_iterations, max_device_iterations);
            assert_eq!(plan.host_sync_points, 1);
            assert_eq!(plan.changed_flag_readback_bytes, 4);
            assert_eq!(plan.host_iteration_polls, 0);
            assert_eq!(
                plan.readback_policy,
                ConvergenceReadbackPolicy::FinalFlagOnly
            );
        }
    }
}