1use std::collections::HashMap;
19
20use anyhow::Result;
21use rlx_ir::hir::{HirModule, HirNodeId};
22use rlx_ir::{DType, GraphModule, Shape};
23
24use crate::profile::CompileProfile;
25use crate::value::FlowValue;
26use crate::weight::WeightSource;
27
28#[derive(Debug, Clone, Copy)]
30pub struct GdnInputSlots {
31 pub q: HirNodeId,
32 pub k: HirNodeId,
33 pub v: HirNodeId,
34 pub g: HirNodeId,
35 pub beta: HirNodeId,
36}
37
38#[derive(Debug, Default)]
40pub struct FlowState {
41 pub rope_cos: Option<HirNodeId>,
42 pub rope_sin: Option<HirNodeId>,
43 pub zero_beta: Option<HirNodeId>,
44 pub embed_weight: Option<HirNodeId>,
45 pub hidden_shape: Option<Shape>,
46 pub decode: Option<DecodeBindings>,
47 pub residual_skip: Option<HirNodeId>,
48 pub residual_shape: Option<Shape>,
49 pub streams: HashMap<String, FlowValue>,
51 pub inputs: HashMap<String, (HirNodeId, Shape)>,
53 pub named: HashMap<String, HirNodeId>,
55 pub gdn: Option<GdnInputSlots>,
57 pub loaded_params: HashMap<String, HirNodeId>,
60}
61
62#[derive(Debug, Clone)]
64pub struct DecodeBindings {
65 pub cos: HirNodeId,
66 pub sin: HirNodeId,
67 pub mask: Option<HirNodeId>,
68 pub past_k: Vec<HirNodeId>,
69 pub past_v: Vec<HirNodeId>,
70}
71
72pub struct FlowCtx<'a> {
74 pub(crate) module: GraphModule,
75 pub(crate) params: &'a mut HashMap<String, Vec<f32>>,
76 pub(crate) weights: &'a mut dyn WeightSource,
77 pub(crate) profile: &'a CompileProfile,
78 pub(crate) state: &'a mut FlowState,
79}
80
81impl FlowCtx<'_> {
82 pub fn hir(&mut self) -> &mut HirModule {
83 self.module
84 .as_hir_mut()
85 .expect("flow context requires HIR stage")
86 }
87
88 pub fn node_shape(&self, id: HirNodeId) -> Result<Shape> {
89 Ok(self
90 .module
91 .as_hir()
92 .ok_or_else(|| anyhow::anyhow!("flow context requires HIR stage"))?
93 .node(id)
94 .shape
95 .clone())
96 }
97
98 pub fn load_param(&mut self, key: &str, transpose: bool) -> Result<HirNodeId> {
99 let cache_key = param_cache_key(key, transpose);
100 if let Some(&id) = self.state.loaded_params.get(&cache_key) {
101 return Ok(id);
102 }
103 let (data, shape) = self.weights.take(key, transpose)?;
104 let ir_shape = Shape::new(&shape, DType::F32);
105 let id = self.hir().param(key, ir_shape);
106 self.params.insert(key.to_string(), data);
107 self.state.loaded_params.insert(cache_key, id);
108 Ok(id)
109 }
110
111 pub fn synth_param(&mut self, name: &str, data: Vec<f32>, shape: Shape) -> HirNodeId {
112 let id = self.hir().param(name, shape);
113 self.params.insert(name.to_string(), data);
114 id
115 }
116
117 pub fn synth_zeros(&mut self, name: &str, len: usize) -> HirNodeId {
118 self.synth_param(name, vec![0f32; len], Shape::new(&[len], DType::F32))
119 }
120
121 pub fn input(&mut self, name: &str, shape: Shape) -> HirNodeId {
122 self.hir().input(name, shape)
123 }
124
125 pub fn wrap(&self, id: HirNodeId, shape: Shape) -> FlowValue {
126 FlowValue::new(id, shape)
127 }
128}
129
/// Builds the lookup key used for the loaded-parameter cache.
///
/// A non-transposed load uses `key` unchanged; a transposed load appends the
/// two-character suffix `"\0t"` (NUL followed by `t`), so the two variants of
/// the same weight get different cache entries.
fn param_cache_key(key: &str, transpose: bool) -> String {
    const TRANSPOSE_SUFFIX: &str = "\0t";

    let mut cache_key = String::with_capacity(key.len() + TRANSPOSE_SUFFIX.len());
    cache_key.push_str(key);
    if transpose {
        cache_key.push_str(TRANSPOSE_SUFFIX);
    }
    cache_key
}