Skip to main content

agpu/
compute.rs

//! Compute shader support — general-purpose GPU compute exposed through the ontology.
//!
//! Provides [`ComputeTask`] for defining GPU compute operations with named
//! input/output buffer bindings ([`ComputeBinding`]), and [`ComputeDispatch`] for execution parameters.
5
6use crate::ontology::{
7    AgentAction, AgentCapability, Discoverable, SemanticRole, UiNode, WidgetSchema,
8};
9
/// Workgroup counts for a single compute dispatch along the X, Y and Z axes.
///
/// These are *workgroup* counts, not invocation counts: the total number of
/// shader invocations is the product of these counts and the task's workgroup
/// size (see [`ComputeDispatch::total_invocations`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ComputeDispatch {
    /// Number of workgroups along the X axis.
    pub workgroups_x: u32,
    /// Number of workgroups along the Y axis.
    pub workgroups_y: u32,
    /// Number of workgroups along the Z axis.
    pub workgroups_z: u32,
}
17
18impl ComputeDispatch {
19    pub fn new(x: u32, y: u32, z: u32) -> Self {
20        Self {
21            workgroups_x: x,
22            workgroups_y: y,
23            workgroups_z: z,
24        }
25    }
26
27    pub fn linear(count: u32, workgroup_size: u32) -> Self {
28        let groups = count.div_ceil(workgroup_size);
29        Self::new(groups, 1, 1)
30    }
31
32    pub fn grid_2d(width: u32, height: u32, workgroup_size: u32) -> Self {
33        let gx = width.div_ceil(workgroup_size);
34        let gy = height.div_ceil(workgroup_size);
35        Self::new(gx, gy, 1)
36    }
37
38    pub fn total_invocations(&self, workgroup_size: [u32; 3]) -> u64 {
39        self.workgroups_x as u64
40            * workgroup_size[0] as u64
41            * self.workgroups_y as u64
42            * workgroup_size[1] as u64
43            * self.workgroups_z as u64
44            * workgroup_size[2] as u64
45    }
46}
47
/// Describes a buffer binding slot used by a [`ComputeTask`].
///
/// `group` and `binding` identify the slot. Presumably they mirror the shader's
/// `@group(..) @binding(..)` attributes — TODO confirm against the dispatch code.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ComputeBinding {
    /// Name of the buffer; this is what the `get_bindings` agent action reports.
    pub name: String,
    /// Bind group index.
    pub group: u32,
    /// Binding index within the group.
    pub binding: u32,
    /// `true` when the binding is declared read-only, `false` for read-write.
    pub read_only: bool,
}
56
57impl ComputeBinding {
58    pub fn storage(name: impl Into<String>, group: u32, binding: u32) -> Self {
59        Self {
60            name: name.into(),
61            group,
62            binding,
63            read_only: false,
64        }
65    }
66
67    pub fn read_only(name: impl Into<String>, group: u32, binding: u32) -> Self {
68        Self {
69            name: name.into(),
70            group,
71            binding,
72            read_only: true,
73        }
74    }
75}
76
77/// A compute task definition with shader source, bindings, and workgroup size.
78pub struct ComputeTask {
79    id: String,
80    label: String,
81    shader_source: String,
82    entry_point: String,
83    workgroup_size: [u32; 3],
84    bindings: Vec<ComputeBinding>,
85}
86
87impl ComputeTask {
88    pub fn new(
89        id: impl Into<String>,
90        label: impl Into<String>,
91        shader_source: impl Into<String>,
92    ) -> Self {
93        Self {
94            id: id.into(),
95            label: label.into(),
96            shader_source: shader_source.into(),
97            entry_point: "main".to_string(),
98            workgroup_size: [64, 1, 1],
99            bindings: Vec::new(),
100        }
101    }
102
103    pub fn entry_point(mut self, entry: impl Into<String>) -> Self {
104        self.entry_point = entry.into();
105        self
106    }
107
108    pub fn workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
109        self.workgroup_size = [x, y, z];
110        self
111    }
112
113    pub fn binding(mut self, binding: ComputeBinding) -> Self {
114        self.bindings.push(binding);
115        self
116    }
117
118    pub fn id(&self) -> &str {
119        &self.id
120    }
121
122    pub fn label(&self) -> &str {
123        &self.label
124    }
125
126    pub fn shader_source(&self) -> &str {
127        &self.shader_source
128    }
129
130    pub fn entry(&self) -> &str {
131        &self.entry_point
132    }
133
134    pub fn workgroup(&self) -> [u32; 3] {
135        self.workgroup_size
136    }
137
138    pub fn bindings(&self) -> &[ComputeBinding] {
139        &self.bindings
140    }
141}
142
143impl Discoverable for ComputeTask {
144    fn schema(&self) -> WidgetSchema {
145        WidgetSchema::new("ComputeTask", "A GPU compute task", SemanticRole::System)
146    }
147
148    fn capabilities(&self) -> Vec<AgentCapability> {
149        vec![AgentCapability::Custom("gpu_compute".into())]
150    }
151
152    fn actions(&self) -> Vec<AgentAction> {
153        vec![
154            AgentAction::simple("dispatch", "Execute the compute task", true),
155            AgentAction::simple("get_bindings", "List buffer bindings", false),
156        ]
157    }
158
159    fn semantic_role(&self) -> SemanticRole {
160        SemanticRole::System
161    }
162
163    fn agent_state(&self) -> serde_json::Value {
164        serde_json::json!({
165            "id": self.id,
166            "label": self.label,
167            "entry_point": self.entry_point,
168            "workgroup_size": self.workgroup_size,
169            "bindings": self.bindings.iter().map(|b| {
170                serde_json::json!({
171                    "name": b.name,
172                    "group": b.group,
173                    "binding": b.binding,
174                    "read_only": b.read_only,
175                })
176            }).collect::<Vec<_>>(),
177        })
178    }
179
180    fn execute_action(
181        &mut self,
182        action: &str,
183        _params: &serde_json::Value,
184    ) -> Result<serde_json::Value, String> {
185        match action {
186            "get_bindings" => Ok(serde_json::json!({
187                "bindings": self.bindings.iter().map(|b| &b.name).collect::<Vec<_>>()
188            })),
189            "dispatch" => {
190                // Actual GPU dispatch requires GpuContext; this is the ontology interface
191                Ok(serde_json::json!({ "status": "requires_gpu_context" }))
192            }
193            _ => Err(format!("Unknown action: {action}")),
194        }
195    }
196
197    fn agent_id(&self) -> Option<&str> {
198        Some(&self.id)
199    }
200}
201
202impl ComputeTask {
203    pub fn ui_node(&self) -> UiNode {
204        UiNode::new("ComputeTask", SemanticRole::System).with_id(&self.id)
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dispatch_linear() {
        let d = ComputeDispatch::linear(1000, 64);
        assert_eq!(d.workgroups_x, 16); // ceil(1000/64) = 16
        assert_eq!(d.workgroups_y, 1);
        assert_eq!(d.workgroups_z, 1);
    }

    #[test]
    fn dispatch_linear_exact_multiple_and_empty() {
        assert_eq!(ComputeDispatch::linear(1024, 64).workgroups_x, 16);
        assert_eq!(ComputeDispatch::linear(0, 64).workgroups_x, 0);
    }

    #[test]
    #[should_panic]
    fn dispatch_linear_zero_workgroup_size_panics() {
        let _ = ComputeDispatch::linear(10, 0);
    }

    #[test]
    fn dispatch_grid_2d() {
        let d = ComputeDispatch::grid_2d(256, 256, 16);
        assert_eq!(d.workgroups_x, 16);
        assert_eq!(d.workgroups_y, 16);
        assert_eq!(d.workgroups_z, 1);
    }

    #[test]
    fn dispatch_grid_2d_rounds_up_each_axis() {
        let d = ComputeDispatch::grid_2d(100, 33, 16);
        assert_eq!(d.workgroups_x, 7); // ceil(100/16)
        assert_eq!(d.workgroups_y, 3); // ceil(33/16)
        assert_eq!(d.workgroups_z, 1);
    }

    #[test]
    fn dispatch_total_invocations() {
        let d = ComputeDispatch::new(4, 4, 1);
        assert_eq!(d.total_invocations([8, 8, 1]), 4 * 8 * 4 * 8);
        assert_eq!(ComputeDispatch::new(0, 4, 1).total_invocations([8, 8, 1]), 0);
    }

    #[test]
    fn compute_binding_storage() {
        let b = ComputeBinding::storage("output", 0, 2);
        assert!(!b.read_only);
        assert_eq!(b.name, "output");
        assert_eq!((b.group, b.binding), (0, 2));
    }

    #[test]
    fn compute_binding_readonly() {
        let b = ComputeBinding::read_only("input", 1, 3);
        assert!(b.read_only);
        assert_eq!(b.name, "input");
        assert_eq!((b.group, b.binding), (1, 3));
    }

    #[test]
    fn compute_task_defaults() {
        let task = ComputeTask::new("t1", "Test", "src");
        assert_eq!(task.id(), "t1");
        assert_eq!(task.label(), "Test");
        assert_eq!(task.shader_source(), "src");
        assert_eq!(task.entry(), "main");
        assert_eq!(task.workgroup(), [64, 1, 1]);
        assert!(task.bindings().is_empty());
    }

    #[test]
    fn compute_task_builder() {
        let task = ComputeTask::new("add_vectors", "Vector addition", "@compute fn main() {}")
            .entry_point("cs_main")
            .workgroup_size(256, 1, 1)
            .binding(ComputeBinding::read_only("a", 0, 0))
            .binding(ComputeBinding::read_only("b", 0, 1))
            .binding(ComputeBinding::storage("result", 0, 2));
        assert_eq!(task.bindings().len(), 3);
        assert_eq!(task.workgroup(), [256, 1, 1]);
        assert_eq!(task.entry(), "cs_main");
        let names: Vec<&str> = task.bindings().iter().map(|b| b.name.as_str()).collect();
        assert_eq!(names, ["a", "b", "result"]);
    }

    #[test]
    fn compute_task_discoverable() {
        let task = ComputeTask::new("t1", "Test", "");
        assert_eq!(task.semantic_role(), SemanticRole::System);
        assert_eq!(task.agent_id(), Some("t1"));
        let state = task.agent_state();
        assert_eq!(state["id"], "t1");
        assert_eq!(state["label"], "Test");
        assert_eq!(state["entry_point"], "main");
        assert_eq!(state["workgroup_size"], serde_json::json!([64, 1, 1]));
        assert_eq!(state["bindings"], serde_json::json!([]));
    }

    #[test]
    fn compute_task_agent_state_lists_binding_details() {
        let task = ComputeTask::new("t1", "Test", "").binding(ComputeBinding::read_only("in", 1, 2));
        let state = task.agent_state();
        assert_eq!(
            state["bindings"],
            serde_json::json!([{ "name": "in", "group": 1, "binding": 2, "read_only": true }])
        );
    }

    #[test]
    fn compute_task_execute_get_bindings() {
        let mut task =
            ComputeTask::new("t1", "Test", "").binding(ComputeBinding::storage("out", 0, 0));
        let result = task.execute_action("get_bindings", &serde_json::json!({}));
        assert_eq!(result, Ok(serde_json::json!({ "bindings": ["out"] })));
    }

    #[test]
    fn compute_task_execute_dispatch_reports_missing_context() {
        let mut task = ComputeTask::new("t1", "Test", "");
        let result = task.execute_action("dispatch", &serde_json::json!({}));
        assert_eq!(result, Ok(serde_json::json!({ "status": "requires_gpu_context" })));
    }

    #[test]
    fn compute_task_unknown_action() {
        let mut task = ComputeTask::new("t1", "Test", "");
        let result = task.execute_action("nope", &serde_json::json!({}));
        assert_eq!(result, Err("Unknown action: nope".to_string()));
    }
}