1use crate::ontology::{
7 AgentAction, AgentCapability, Discoverable, SemanticRole, UiNode, WidgetSchema,
8};
9
/// Workgroup counts along the X, Y and Z axes for a single compute dispatch.
///
/// This only counts workgroups; the number of invocations inside each
/// workgroup is supplied separately (see [`ComputeDispatch::total_invocations`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ComputeDispatch {
    /// Number of workgroups along X.
    pub workgroups_x: u32,
    /// Number of workgroups along Y.
    pub workgroups_y: u32,
    /// Number of workgroups along Z.
    pub workgroups_z: u32,
}

impl ComputeDispatch {
    /// Creates a dispatch from explicit per-axis workgroup counts.
    pub fn new(x: u32, y: u32, z: u32) -> Self {
        Self {
            workgroups_x: x,
            workgroups_y: y,
            workgroups_z: z,
        }
    }

    /// One-dimensional dispatch covering `count` items with `workgroup_size`
    /// items per workgroup.
    ///
    /// The workgroup count is rounded up, so the final workgroup may extend past
    /// `count`; invocations beyond `count` must be ignored by the shader.
    /// A `count` of zero yields zero workgroups.
    ///
    /// # Panics
    ///
    /// Panics if `workgroup_size` is zero.
    pub fn linear(count: u32, workgroup_size: u32) -> Self {
        assert!(workgroup_size > 0, "workgroup_size must be non-zero");
        Self::new(count.div_ceil(workgroup_size), 1, 1)
    }

    /// Two-dimensional dispatch over a `width` x `height` domain, using square
    /// `workgroup_size` x `workgroup_size` tiles.
    ///
    /// Both axes are rounded up independently, so edge tiles may be partially
    /// outside the domain.
    ///
    /// # Panics
    ///
    /// Panics if `workgroup_size` is zero.
    pub fn grid_2d(width: u32, height: u32, workgroup_size: u32) -> Self {
        assert!(workgroup_size > 0, "workgroup_size must be non-zero");
        Self::new(
            width.div_ceil(workgroup_size),
            height.div_ceil(workgroup_size),
            1,
        )
    }

    /// Total number of shader invocations this dispatch launches when each
    /// workgroup has the given `[x, y, z]` size.
    ///
    /// The product of six `u32` factors can exceed `u64::MAX`. A plain multiply
    /// would panic in debug builds and wrap silently in release builds, so the
    /// result saturates at `u64::MAX` instead. A zero factor still yields 0:
    /// `saturating_mul(u64::MAX, 0) == 0`.
    pub fn total_invocations(&self, workgroup_size: [u32; 3]) -> u64 {
        let groups = [self.workgroups_x, self.workgroups_y, self.workgroups_z];
        groups
            .iter()
            .chain(workgroup_size.iter())
            .fold(1u64, |acc, &factor| acc.saturating_mul(u64::from(factor)))
    }
}
47
/// A named buffer bound at a (`group`, `binding`) slot of a compute shader.
#[derive(Debug, Clone)]
pub struct ComputeBinding {
    /// Identifier for the buffer; reported back in agent-facing state.
    pub name: String,
    /// Bind-group index.
    pub group: u32,
    /// Binding index within the bind group.
    pub binding: u32,
    /// `true` for bindings built with [`ComputeBinding::read_only`],
    /// `false` for those built with [`ComputeBinding::storage`].
    pub read_only: bool,
}

impl ComputeBinding {
    /// Builds a binding that is not marked read-only.
    pub fn storage(name: impl Into<String>, group: u32, binding: u32) -> Self {
        Self::from_parts(name.into(), group, binding, false)
    }

    /// Builds a binding marked read-only.
    pub fn read_only(name: impl Into<String>, group: u32, binding: u32) -> Self {
        Self::from_parts(name.into(), group, binding, true)
    }

    /// Shared field-by-field constructor behind the public access-specific ones.
    fn from_parts(name: String, group: u32, binding: u32, read_only: bool) -> Self {
        Self {
            name,
            group,
            binding,
            read_only,
        }
    }
}
76
77pub struct ComputeTask {
79 id: String,
80 label: String,
81 shader_source: String,
82 entry_point: String,
83 workgroup_size: [u32; 3],
84 bindings: Vec<ComputeBinding>,
85}
86
87impl ComputeTask {
88 pub fn new(
89 id: impl Into<String>,
90 label: impl Into<String>,
91 shader_source: impl Into<String>,
92 ) -> Self {
93 Self {
94 id: id.into(),
95 label: label.into(),
96 shader_source: shader_source.into(),
97 entry_point: "main".to_string(),
98 workgroup_size: [64, 1, 1],
99 bindings: Vec::new(),
100 }
101 }
102
103 pub fn entry_point(mut self, entry: impl Into<String>) -> Self {
104 self.entry_point = entry.into();
105 self
106 }
107
108 pub fn workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
109 self.workgroup_size = [x, y, z];
110 self
111 }
112
113 pub fn binding(mut self, binding: ComputeBinding) -> Self {
114 self.bindings.push(binding);
115 self
116 }
117
118 pub fn id(&self) -> &str {
119 &self.id
120 }
121
122 pub fn label(&self) -> &str {
123 &self.label
124 }
125
126 pub fn shader_source(&self) -> &str {
127 &self.shader_source
128 }
129
130 pub fn entry(&self) -> &str {
131 &self.entry_point
132 }
133
134 pub fn workgroup(&self) -> [u32; 3] {
135 self.workgroup_size
136 }
137
138 pub fn bindings(&self) -> &[ComputeBinding] {
139 &self.bindings
140 }
141}
142
143impl Discoverable for ComputeTask {
144 fn schema(&self) -> WidgetSchema {
145 WidgetSchema::new("ComputeTask", "A GPU compute task", SemanticRole::System)
146 }
147
148 fn capabilities(&self) -> Vec<AgentCapability> {
149 vec![AgentCapability::Custom("gpu_compute".into())]
150 }
151
152 fn actions(&self) -> Vec<AgentAction> {
153 vec![
154 AgentAction::simple("dispatch", "Execute the compute task", true),
155 AgentAction::simple("get_bindings", "List buffer bindings", false),
156 ]
157 }
158
159 fn semantic_role(&self) -> SemanticRole {
160 SemanticRole::System
161 }
162
163 fn agent_state(&self) -> serde_json::Value {
164 serde_json::json!({
165 "id": self.id,
166 "label": self.label,
167 "entry_point": self.entry_point,
168 "workgroup_size": self.workgroup_size,
169 "bindings": self.bindings.iter().map(|b| {
170 serde_json::json!({
171 "name": b.name,
172 "group": b.group,
173 "binding": b.binding,
174 "read_only": b.read_only,
175 })
176 }).collect::<Vec<_>>(),
177 })
178 }
179
180 fn execute_action(
181 &mut self,
182 action: &str,
183 _params: &serde_json::Value,
184 ) -> Result<serde_json::Value, String> {
185 match action {
186 "get_bindings" => Ok(serde_json::json!({
187 "bindings": self.bindings.iter().map(|b| &b.name).collect::<Vec<_>>()
188 })),
189 "dispatch" => {
190 Ok(serde_json::json!({ "status": "requires_gpu_context" }))
192 }
193 _ => Err(format!("Unknown action: {action}")),
194 }
195 }
196
197 fn agent_id(&self) -> Option<&str> {
198 Some(&self.id)
199 }
200}
201
202impl ComputeTask {
203 pub fn ui_node(&self) -> UiNode {
204 UiNode::new("ComputeTask", SemanticRole::System).with_id(&self.id)
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dispatch_linear() {
        let d = ComputeDispatch::linear(1000, 64);
        assert_eq!(d.workgroups_x, 16);
        assert_eq!(d.workgroups_y, 1);
        assert_eq!(d.workgroups_z, 1);
    }

    #[test]
    fn dispatch_linear_exact_multiple_and_empty() {
        // An exact multiple must not add an extra workgroup; zero items need none.
        assert_eq!(ComputeDispatch::linear(1024, 64).workgroups_x, 16);
        assert_eq!(ComputeDispatch::linear(0, 64).workgroups_x, 0);
    }

    #[test]
    fn dispatch_grid_2d() {
        let d = ComputeDispatch::grid_2d(256, 256, 16);
        assert_eq!(d.workgroups_x, 16);
        assert_eq!(d.workgroups_y, 16);
        assert_eq!(d.workgroups_z, 1);
    }

    #[test]
    fn dispatch_grid_2d_rounds_up() {
        let d = ComputeDispatch::grid_2d(257, 1, 16);
        assert_eq!(d.workgroups_x, 17);
        assert_eq!(d.workgroups_y, 1);
    }

    #[test]
    fn dispatch_total_invocations() {
        let d = ComputeDispatch::new(4, 4, 1);
        assert_eq!(d.total_invocations([8, 8, 1]), 4 * 8 * 4 * 8);
    }

    #[test]
    fn compute_binding_storage() {
        let b = ComputeBinding::storage("output", 0, 0);
        assert_eq!(b.name, "output");
        assert_eq!((b.group, b.binding), (0, 0));
        assert!(!b.read_only);
    }

    #[test]
    fn compute_binding_readonly() {
        let b = ComputeBinding::read_only("input", 0, 1);
        assert_eq!(b.name, "input");
        assert_eq!((b.group, b.binding), (0, 1));
        assert!(b.read_only);
    }

    #[test]
    fn compute_task_defaults() {
        let task = ComputeTask::new("t1", "Test", "src");
        assert_eq!(task.id(), "t1");
        assert_eq!(task.label(), "Test");
        assert_eq!(task.shader_source(), "src");
        assert_eq!(task.entry(), "main");
        assert_eq!(task.workgroup(), [64, 1, 1]);
        assert!(task.bindings().is_empty());
    }

    #[test]
    fn compute_task_builder() {
        let task = ComputeTask::new("add_vectors", "Vector addition", "@compute fn main() {}")
            .entry_point("add")
            .workgroup_size(256, 1, 1)
            .binding(ComputeBinding::read_only("a", 0, 0))
            .binding(ComputeBinding::read_only("b", 0, 1))
            .binding(ComputeBinding::storage("result", 0, 2));
        assert_eq!(task.entry(), "add");
        assert_eq!(task.workgroup(), [256, 1, 1]);
        let names: Vec<&str> = task.bindings().iter().map(|b| b.name.as_str()).collect();
        assert_eq!(names, vec!["a", "b", "result"]);
        assert!(!task.bindings()[2].read_only);
    }

    #[test]
    fn compute_task_discoverable() {
        let task = ComputeTask::new("t1", "Test", "").binding(ComputeBinding::storage("out", 0, 3));
        assert_eq!(task.semantic_role(), SemanticRole::System);
        assert_eq!(task.agent_id(), Some("t1"));
        let state = task.agent_state();
        assert_eq!(state["id"], "t1");
        assert_eq!(state["label"], "Test");
        assert_eq!(state["entry_point"], "main");
        assert_eq!(state["workgroup_size"], serde_json::json!([64, 1, 1]));
        assert_eq!(state["bindings"][0]["name"], "out");
        assert_eq!(state["bindings"][0]["binding"], 3);
        assert_eq!(state["bindings"][0]["read_only"], false);
    }

    #[test]
    fn compute_task_execute_get_bindings() {
        let mut task =
            ComputeTask::new("t1", "Test", "").binding(ComputeBinding::storage("out", 0, 0));
        let result = task
            .execute_action("get_bindings", &serde_json::json!({}))
            .expect("get_bindings is a known action");
        assert_eq!(result["bindings"], serde_json::json!(["out"]));
    }

    #[test]
    fn compute_task_execute_dispatch() {
        let mut task = ComputeTask::new("t1", "Test", "");
        let result = task
            .execute_action("dispatch", &serde_json::json!({}))
            .expect("dispatch is a known action");
        assert_eq!(result["status"], "requires_gpu_context");
    }

    #[test]
    fn compute_task_unknown_action() {
        let mut task = ComputeTask::new("t1", "Test", "");
        let result = task.execute_action("nope", &serde_json::json!({}));
        assert_eq!(result.unwrap_err(), "Unknown action: nope");
    }
}