1use super::config::{SAM_EMBED_HW, SamEncoderConfig};
34use super::preprocess::{SamPreprocessWeights, extract_preprocess_weights};
35use anyhow::{Result, anyhow, ensure};
36use rlx_core::vision_ops_ir::{bhwc_to_nchw, conv2d_bias, conv2d_no_bias, layer_norm2d_nchw};
37use rlx_core::weight_map::WeightMap;
38use rlx_ir::HirGraphExt;
39use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
40use rlx_ir::*;
41use std::collections::HashMap;
42
43struct SamBuilder {
44 hir: HirModule,
45 params: HashMap<String, Vec<f32>>,
46}
47
48impl SamBuilder {
49 fn new(name: &str) -> Self {
50 Self {
51 hir: HirModule::new(name),
52 params: HashMap::new(),
53 }
54 }
55
56 fn m(&mut self) -> HirMut<'_> {
57 HirMut::new(&mut self.hir)
58 }
59}
60
61#[allow(dead_code)]
62fn lower_hir(hir: HirModule) -> Result<Graph> {
63 Graph::from_hir(hir).map_err(|e| anyhow!("{e}"))
64}
65
66pub fn build_sam_encoder_hir(
73 cfg: &SamEncoderConfig,
74 weights: &mut WeightMap,
75) -> Result<(HirModule, HashMap<String, Vec<f32>>, SamPreprocessWeights)> {
76 let mut b = SamBuilder::new("sam_image_encoder");
77 let f = DType::F32;
78
79 let preprocess = extract_preprocess_weights(weights, cfg)?;
83
84 let e = cfg.embed_dim;
85 let nh = cfg.num_heads;
86 let dh = cfg.head_dim();
87 let scale = 1.0 / (dh as f32).sqrt();
88 let eps = cfg.layer_norm_eps as f32;
89 let hw = SAM_EMBED_HW;
90 let s = hw * hw; let hidden_input = b.m().input("hidden", Shape::new(&[1, s, e], f));
94
95 let mut x = hidden_input;
96 for layer_idx in 0..cfg.depth {
97 let lp = format!("image_encoder.blocks.{layer_idx}");
98 let is_global = cfg.global_attn_indexes.contains(&layer_idx);
99 let ws = if is_global { 0 } else { cfg.window_size };
100
101 let n1_g = load_p(&mut b, weights, &format!("{lp}.norm1.weight"), false)?;
103 let n1_b = load_p(&mut b, weights, &format!("{lp}.norm1.bias"), false)?;
104 let normed = b.m().ln(x, n1_g, n1_b, eps);
105
106 let attn_out = if ws == 0 {
108 attention_global(
109 &mut b,
110 weights,
111 &lp,
112 normed,
113 e,
114 nh,
115 dh,
116 scale,
117 hw,
118 cfg.use_rel_pos,
119 cfg.qkv_bias,
120 )?
121 } else {
122 attention_windowed(
123 &mut b,
124 weights,
125 &lp,
126 normed,
127 e,
128 nh,
129 dh,
130 scale,
131 hw,
132 ws,
133 cfg.use_rel_pos,
134 cfg.qkv_bias,
135 )?
136 };
137
138 x = b.m().add(x, attn_out);
140
141 let n2_g = load_p(&mut b, weights, &format!("{lp}.norm2.weight"), false)?;
143 let n2_b = load_p(&mut b, weights, &format!("{lp}.norm2.bias"), false)?;
144 let normed2 = b.m().ln(x, n2_g, n2_b, eps);
145
146 let fc1_w = load_p(&mut b, weights, &format!("{lp}.mlp.lin1.weight"), true)?;
147 let fc1_b = load_p(&mut b, weights, &format!("{lp}.mlp.lin1.bias"), false)?;
148 let fc2_w = load_p(&mut b, weights, &format!("{lp}.mlp.lin2.weight"), true)?;
149 let fc2_b = load_p(&mut b, weights, &format!("{lp}.mlp.lin2.bias"), false)?;
150
151 let up_mm = b.m().mm(normed2, fc1_w);
152 let up = b.m().add(up_mm, fc1_b);
153 let act = b.m().gelu(up);
157 let down_mm = b.m().mm(act, fc2_w);
158 let ffn = b.m().add(down_mm, fc2_b);
159
160 x = b.m().add(x, ffn);
161 }
162
163 let oc = cfg.out_chans;
165 let nchw = bhwc_to_nchw(&mut b.m(), x, 1, hw, hw, e);
166 let c1_w = load_p(&mut b, weights, "image_encoder.neck.0.weight", false)?;
167 let c1_b = load_p(&mut b, weights, "image_encoder.neck.0.bias", false)?;
168 let feat = conv2d_bias(
169 &mut b.m(),
170 nchw,
171 c1_w,
172 c1_b,
173 1,
174 oc,
175 1,
176 1,
177 [1, 1],
178 [0, 0],
179 hw,
180 hw,
181 );
182 let ln1_g = load_p(&mut b, weights, "image_encoder.neck.1.weight", false)?;
183 let ln1_b = load_p(&mut b, weights, "image_encoder.neck.1.bias", false)?;
184 let feat = layer_norm2d_nchw(&mut b.m(), feat, ln1_g, ln1_b, eps);
185 let c2_w = load_p(&mut b, weights, "image_encoder.neck.2.weight", false)?;
186 let feat = conv2d_no_bias(&mut b.m(), feat, c2_w, 1, oc, 3, 3, [1, 1], [1, 1], hw, hw);
187 let ln2_g = load_p(&mut b, weights, "image_encoder.neck.3.weight", false)?;
188 let ln2_b = load_p(&mut b, weights, "image_encoder.neck.3.bias", false)?;
189 let out = layer_norm2d_nchw(&mut b.m(), feat, ln2_g, ln2_b, eps);
190
191 b.hir.set_outputs(vec![out]);
192
193 Ok((b.hir, b.params, preprocess))
194}
195
196pub fn build_sam_encoder_graph(
198 cfg: &SamEncoderConfig,
199 weights: &mut WeightMap,
200) -> Result<(Graph, HashMap<String, Vec<f32>>, SamPreprocessWeights)> {
201 let built = super::flow::build_sam_encoder_built(cfg, weights)?;
202 let (graph, params) = rlx_core::flow_util::graph_from_built(built.model)?;
203 Ok((graph, params, built.preprocess))
204}
205
206#[allow(clippy::too_many_arguments)]
208fn attention_global(
209 sb: &mut SamBuilder,
210 w: &mut WeightMap,
211 lp: &str,
212 x: HirNodeId, e: usize,
214 nh: usize,
215 dh: usize,
216 scale: f32,
217 hw: usize,
218 use_rel_pos: bool,
219 qkv_bias: bool,
220) -> Result<HirNodeId> {
221 let s = hw * hw;
222 decomposed_attention(
223 sb,
224 w,
225 lp,
226 x,
227 e,
228 nh,
229 dh,
230 scale,
231 hw,
232 hw,
233 s,
234 1,
235 use_rel_pos,
236 qkv_bias,
237 )
238}
239
240#[allow(clippy::too_many_arguments)]
243fn attention_windowed(
244 sb: &mut SamBuilder,
245 w: &mut WeightMap,
246 lp: &str,
247 x: HirNodeId, e: usize,
249 nh: usize,
250 dh: usize,
251 scale: f32,
252 hw: usize,
253 ws: usize,
254 use_rel_pos: bool,
255 qkv_bias: bool,
256) -> Result<HirNodeId> {
257 let bhwc = sb.m().reshape_(x, vec![1, hw as i64, hw as i64, e as i64]);
259
260 let pad = (ws - hw % ws) % ws;
261 let hw_p = hw + pad;
262 let n_win_per_side = hw_p / ws;
263 let n_win = n_win_per_side * n_win_per_side;
264
265 let padded = if pad > 0 {
267 let z_h = pad_zero_param(sb, &format!("{lp}.attn._pad_h"), &[1, pad, hw, e]);
268 let p1 = sb.m().concat_(vec![bhwc, z_h], 1); let z_w = pad_zero_param(sb, &format!("{lp}.attn._pad_w"), &[1, hw_p, pad, e]);
270 sb.m().concat_(vec![p1, z_w], 2) } else {
272 bhwc
273 };
274
275 let reshaped = sb.m().reshape_(
278 padded,
279 vec![
280 1,
281 n_win_per_side as i64,
282 ws as i64,
283 n_win_per_side as i64,
284 ws as i64,
285 e as i64,
286 ],
287 );
288 let transposed = sb.m().transpose_(reshaped, vec![0, 1, 3, 2, 4, 5]);
289 let windowed = sb.m().reshape_(
290 transposed,
291 vec![n_win as i64, ws as i64, ws as i64, e as i64],
292 );
293 let win_flat = sb
295 .m()
296 .reshape_(windowed, vec![n_win as i64, (ws * ws) as i64, e as i64]);
297
298 let attn_out = decomposed_attention(
301 sb,
302 w,
303 lp,
304 win_flat,
305 e,
306 nh,
307 dh,
308 scale,
309 ws,
310 ws,
311 ws * ws,
312 n_win,
313 use_rel_pos,
314 qkv_bias,
315 )?;
316 let un = sb
321 .m()
322 .reshape_(attn_out, vec![n_win as i64, ws as i64, ws as i64, e as i64]);
323 let un = sb.m().reshape_(
324 un,
325 vec![
326 1,
327 n_win_per_side as i64,
328 n_win_per_side as i64,
329 ws as i64,
330 ws as i64,
331 e as i64,
332 ],
333 );
334 let un = sb.m().transpose_(un, vec![0, 1, 3, 2, 4, 5]);
335 let un = sb
336 .m()
337 .reshape_(un, vec![1, hw_p as i64, hw_p as i64, e as i64]);
338 let un = if pad > 0 {
340 let cropped_h = sb.m().narrow_(un, 1, 0, hw);
341 sb.m().narrow_(cropped_h, 2, 0, hw)
342 } else {
343 un
344 };
345 Ok(sb.m().reshape_(un, vec![1, (hw * hw) as i64, e as i64]))
347}
348
349#[allow(clippy::too_many_arguments)]
356fn decomposed_attention(
357 sb: &mut SamBuilder,
358 w: &mut WeightMap,
359 lp: &str,
360 x: HirNodeId, e: usize,
362 nh: usize,
363 dh: usize,
364 scale: f32,
365 h: usize,
366 w_dim: usize,
367 s: usize, batch: usize,
369 use_rel_pos: bool,
370 qkv_bias: bool,
371) -> Result<HirNodeId> {
372 let qkv_w_node = load_p(sb, w, &format!("{lp}.attn.qkv.weight"), true)?;
377 let qkv_b_node = if qkv_bias {
378 Some(load_p(sb, w, &format!("{lp}.attn.qkv.bias"), false)?)
379 } else {
380 None
381 };
382 let qkv_mm = sb.m().mm(x, qkv_w_node); let qkv = if let Some(b) = qkv_b_node {
384 sb.m().add(qkv_mm, b)
385 } else {
386 qkv_mm
387 };
388
389 let qkv5 = sb
393 .m()
394 .reshape_(qkv, vec![batch as i64, s as i64, 3, nh as i64, dh as i64]);
395 let qkv_perm = sb.m().transpose_(qkv5, vec![2, 0, 3, 1, 4]); let qkv_flat = sb
397 .m()
398 .reshape_(qkv_perm, vec![3, (batch * nh) as i64, s as i64, dh as i64]);
399 let q = sb.m().narrow_(qkv_flat, 0, 0, 1);
400 let q = sb
401 .m()
402 .reshape_(q, vec![(batch * nh) as i64, s as i64, dh as i64]);
403 let k = sb.m().narrow_(qkv_flat, 0, 1, 1);
404 let k = sb
405 .m()
406 .reshape_(k, vec![(batch * nh) as i64, s as i64, dh as i64]);
407 let v = sb.m().narrow_(qkv_flat, 0, 2, 1);
408 let v = sb
409 .m()
410 .reshape_(v, vec![(batch * nh) as i64, s as i64, dh as i64]);
411
412 let scale_node = scalar_param(sb, &format!("{lp}.attn._scale"), scale);
414 let q_scaled = sb.m().mul(q, scale_node);
415 let k_t = sb.m().transpose_(k, vec![0, 2, 1]); let scores = sb.m().mm(q_scaled, k_t); let scores = if use_rel_pos {
420 let (mut r_h_data, mut r_w_data) = extract_rel_pos(w, lp, h, w_dim, dh)?;
424 if rlx_ir::env::flag("RLX_SAM_DEBUG_ZERO_RELPOS") {
429 r_h_data.iter_mut().for_each(|v| *v = 0.0);
430 r_w_data.iter_mut().for_each(|v| *v = 0.0);
431 }
432 if rlx_ir::env::flag("RLX_SAM_DEBUG_ZERO_RELH") {
433 r_h_data.iter_mut().for_each(|v| *v = 0.0);
434 }
435 if rlx_ir::env::flag("RLX_SAM_DEBUG_ZERO_RELW") {
436 r_w_data.iter_mut().for_each(|v| *v = 0.0);
437 }
438 let r_h_node = const_param(
439 sb,
440 &format!("{lp}.attn._rel_h_indexed"),
441 &[h, h, dh],
442 r_h_data,
443 );
444 let r_w_node = const_param(
445 sb,
446 &format!("{lp}.attn._rel_w_indexed"),
447 &[w_dim, w_dim, dh],
448 r_w_data,
449 );
450 add_decomposed_rel_pos(sb, scores, q, r_h_node, r_w_node, batch, nh, h, w_dim, dh)?
451 } else {
452 scores
453 };
454
455 let attn_w = sb.m().sm(scores, -1);
457
458 let attn_v = sb.m().mm(attn_w, v);
460
461 let reshaped = sb
463 .m()
464 .reshape_(attn_v, vec![batch as i64, nh as i64, s as i64, dh as i64]);
465 let perm = sb.m().transpose_(reshaped, vec![0, 2, 1, 3]); let merged = sb
467 .m()
468 .reshape_(perm, vec![batch as i64, s as i64, e as i64]);
469
470 let proj_w = load_p(sb, w, &format!("{lp}.attn.proj.weight"), true)?;
472 let proj_b = load_p(sb, w, &format!("{lp}.attn.proj.bias"), false)?;
473 let proj_mm = sb.m().mm(merged, proj_w);
474 Ok(sb.m().add(proj_mm, proj_b))
475}
476
477#[allow(clippy::too_many_arguments)]
486fn add_decomposed_rel_pos(
487 sb: &mut SamBuilder,
488 scores: HirNodeId, q: HirNodeId, r_h: HirNodeId, r_w: HirNodeId, batch: usize,
493 nh: usize,
494 h: usize,
495 w: usize,
496 dh: usize,
497) -> Result<HirNodeId> {
498 let bh = batch * nh;
499 let r_q = sb
501 .m()
502 .reshape_(q, vec![bh as i64, h as i64, w as i64, dh as i64]);
503
504 let mut rel_h_slices: Vec<HirNodeId> = Vec::with_capacity(h);
513 for h_q in 0..h {
514 let rq_slice = sb.m().narrow_(r_q, 1, h_q, 1); let rq_slice = sb
517 .m()
518 .reshape_(rq_slice, vec![bh as i64, w as i64, dh as i64]);
519 let rh_slice = sb.m().narrow_(r_h, 0, h_q, 1); let rh_slice = sb.m().reshape_(rh_slice, vec![h as i64, dh as i64]); let rh_t = sb.m().transpose_(rh_slice, vec![1, 0]); let mm = sb.m().mm(rq_slice, rh_t); let mm5 = sb.m().reshape_(mm, vec![bh as i64, 1, w as i64, h as i64]);
526 rel_h_slices.push(mm5);
527 }
528 let rel_h_4d = sb.m().concat_(rel_h_slices, 1); let mut rel_w_slices: Vec<HirNodeId> = Vec::with_capacity(w);
532 for w_q in 0..w {
533 let rq_slice = sb.m().narrow_(r_q, 2, w_q, 1); let rq_slice = sb
535 .m()
536 .reshape_(rq_slice, vec![bh as i64, h as i64, dh as i64]);
537 let rw_slice = sb.m().narrow_(r_w, 0, w_q, 1); let rw_slice = sb.m().reshape_(rw_slice, vec![w as i64, dh as i64]); let rw_t = sb.m().transpose_(rw_slice, vec![1, 0]); let mm = sb.m().mm(rq_slice, rw_t); let mm5 = sb.m().reshape_(mm, vec![bh as i64, h as i64, 1, w as i64]);
542 rel_w_slices.push(mm5);
543 }
544 let rel_w_4d = sb.m().concat_(rel_w_slices, 2); let scores_5d = sb.m().reshape_(
557 scores,
558 vec![bh as i64, h as i64, w as i64, h as i64, w as i64],
559 );
560 let rel_h_5d = sb
561 .m()
562 .reshape_(rel_h_4d, vec![bh as i64, h as i64, w as i64, h as i64, 1]);
563 let rel_h_tiled = {
564 let mut copies = Vec::with_capacity(w);
565 for _ in 0..w {
566 copies.push(rel_h_5d);
567 }
568 sb.m().concat_(copies, 4) };
570 let rel_w_5d = sb
571 .m()
572 .reshape_(rel_w_4d, vec![bh as i64, h as i64, w as i64, 1, w as i64]);
573 let rel_w_tiled = {
574 let mut copies = Vec::with_capacity(h);
575 for _ in 0..h {
576 copies.push(rel_w_5d);
577 }
578 sb.m().concat_(copies, 3) };
580 let s1 = sb.m().add(scores_5d, rel_h_tiled);
581 let s2 = sb.m().add(s1, rel_w_tiled);
582 Ok(sb
583 .m()
584 .reshape_(s2, vec![bh as i64, (h * w) as i64, (h * w) as i64]))
585}
586
587fn extract_rel_pos(
594 weights: &mut WeightMap,
595 lp: &str,
596 h: usize,
597 w: usize,
598 dh: usize,
599) -> Result<(Vec<f32>, Vec<f32>)> {
600 let (rel_h_raw, rh_shape) = weights.take(&format!("{lp}.attn.rel_pos_h"))?;
601 let (rel_w_raw, rw_shape) = weights.take(&format!("{lp}.attn.rel_pos_w"))?;
602 ensure!(
603 rh_shape == vec![2 * h - 1, dh],
604 "{lp}.attn.rel_pos_h expected [{}, {dh}], got {rh_shape:?}",
605 2 * h - 1
606 );
607 ensure!(
608 rw_shape == vec![2 * w - 1, dh],
609 "{lp}.attn.rel_pos_w expected [{}, {dh}], got {rw_shape:?}",
610 2 * w - 1
611 );
612
613 let mut r_h = vec![0f32; h * h * dh];
614 for q in 0..h {
615 for k in 0..h {
616 let idx = (q as isize - k as isize + (h as isize - 1)) as usize;
617 let src = &rel_h_raw[idx * dh..(idx + 1) * dh];
618 let dst = &mut r_h[(q * h + k) * dh..(q * h + k + 1) * dh];
619 dst.copy_from_slice(src);
620 }
621 }
622 let mut r_w = vec![0f32; w * w * dh];
623 for q in 0..w {
624 for k in 0..w {
625 let idx = (q as isize - k as isize + (w as isize - 1)) as usize;
626 let src = &rel_w_raw[idx * dh..(idx + 1) * dh];
627 let dst = &mut r_w[(q * w + k) * dh..(q * w + k + 1) * dh];
628 dst.copy_from_slice(src);
629 }
630 }
631 Ok((r_h, r_w))
632}
633
/// Host-side copy of the encoder neck weights, as consumed by `apply_neck_host`.
#[derive(Debug, Clone, PartialEq)]
pub struct NeckWeights {
    /// Weight of the first (1x1) convolution, row-major `[out_chans, embed_dim]`.
    pub conv1_w: Vec<f32>,
    /// Per-channel scale of the first layer norm.
    pub ln1_g: Vec<f32>,
    /// Per-channel shift of the first layer norm.
    pub ln1_b: Vec<f32>,
    /// Weight of the second (3x3) convolution, row-major `[out_chans, out_chans, 3, 3]`.
    pub conv2_w: Vec<f32>,
    /// Per-channel scale of the second layer norm.
    pub ln2_g: Vec<f32>,
    /// Per-channel shift of the second layer norm.
    pub ln2_b: Vec<f32>,
    /// Number of input channels of the first convolution.
    pub embed_dim: usize,
    /// Number of channels produced by both convolutions.
    pub out_chans: usize,
    /// Epsilon added to the variance in both layer norms.
    pub eps: f32,
}
650
651#[allow(dead_code)]
652fn extract_neck_weights(weights: &mut WeightMap, cfg: &SamEncoderConfig) -> Result<NeckWeights> {
653 let (conv1_w_raw, c1_shape) = weights.take("image_encoder.neck.0.weight")?;
654 ensure!(
655 c1_shape == vec![cfg.out_chans, cfg.embed_dim, 1, 1],
656 "neck.0.weight expected [{}, {}, 1, 1], got {c1_shape:?}",
657 cfg.out_chans,
658 cfg.embed_dim
659 );
660 let conv1_w = conv1_w_raw; let (ln1_g, _) = weights.take("image_encoder.neck.1.weight")?;
662 let (ln1_b, _) = weights.take("image_encoder.neck.1.bias")?;
663 let (conv2_w, c2_shape) = weights.take("image_encoder.neck.2.weight")?;
664 ensure!(
665 c2_shape == vec![cfg.out_chans, cfg.out_chans, 3, 3],
666 "neck.2.weight expected [{}, {}, 3, 3], got {c2_shape:?}",
667 cfg.out_chans,
668 cfg.out_chans
669 );
670 let (ln2_g, _) = weights.take("image_encoder.neck.3.weight")?;
671 let (ln2_b, _) = weights.take("image_encoder.neck.3.bias")?;
672 Ok(NeckWeights {
673 conv1_w,
674 ln1_g,
675 ln1_b,
676 conv2_w,
677 ln2_g,
678 ln2_b,
679 embed_dim: cfg.embed_dim,
680 out_chans: cfg.out_chans,
681 eps: cfg.layer_norm_eps as f32,
682 })
683}
684
685pub fn apply_neck_host(neck: &NeckWeights, body_out: &[f32], hw: usize) -> Vec<f32> {
689 let e = neck.embed_dim;
690 let oc = neck.out_chans;
691 let eps = neck.eps;
692
693 let s = hw * hw;
697 let mut feat = vec![0f32; s * oc]; for si in 0..s {
699 for oi in 0..oc {
700 let mut acc = 0f32;
701 for ei in 0..e {
702 acc += body_out[si * e + ei] * neck.conv1_w[oi * e + ei];
703 }
704 feat[si * oc + oi] = acc;
705 }
706 }
707
708 layernorm2d_inplace(&mut feat, s, oc, &neck.ln1_g, &neck.ln1_b, eps);
710
711 let mut nchw = vec![0f32; oc * hw * hw];
714 for y in 0..hw {
715 for x in 0..hw {
716 for c in 0..oc {
717 nchw[c * hw * hw + y * hw + x] = feat[(y * hw + x) * oc + c];
718 }
719 }
720 }
721 let conv2_out = conv2d_3x3_pad1(&nchw, oc, oc, hw, hw, &neck.conv2_w);
722
723 let mut bhwc = vec![0f32; s * oc];
725 for c in 0..oc {
726 for y in 0..hw {
727 for x in 0..hw {
728 bhwc[(y * hw + x) * oc + c] = conv2_out[c * hw * hw + y * hw + x];
729 }
730 }
731 }
732 layernorm2d_inplace(&mut bhwc, s, oc, &neck.ln2_g, &neck.ln2_b, eps);
733
734 let mut out_nchw = vec![0f32; oc * hw * hw];
735 for y in 0..hw {
736 for x in 0..hw {
737 for c in 0..oc {
738 out_nchw[c * hw * hw + y * hw + x] = bhwc[(y * hw + x) * oc + c];
739 }
740 }
741 }
742 out_nchw
743}
744
/// Normalises each of the `s` consecutive rows of `c` values in `data` in place.
///
/// Every row is shifted to zero mean and scaled by `1 / sqrt(var + eps)` (population variance),
/// then multiplied element-wise by `g` and offset by `b`.
///
/// # Panics
/// Panics if `data` holds fewer than `s * c` values, or (when `s > 0`) if `g` or `b` hold fewer
/// than `c` values.
fn layernorm2d_inplace(data: &mut [f32], s: usize, c: usize, g: &[f32], b: &[f32], eps: f32) {
    let width = c as f32;
    for row_idx in 0..s {
        let start = row_idx * c;
        let row = &mut data[start..start + c];

        let mean = row.iter().sum::<f32>() / width;
        let var = row.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / width;
        let inv_std = 1.0 / (var + eps).sqrt();

        for ((value, gamma), beta) in row.iter_mut().zip(&g[..c]).zip(&b[..c]) {
            *value = (*value - mean) * inv_std * gamma + beta;
        }
    }
}
757
/// Direct 3x3 convolution with stride 1 and one pixel of zero padding, no bias.
///
/// `input` is channel-major `[in_c, h, w]`, `weight` is `[out_c, in_c, 3, 3]`, and the result is
/// `[out_c, h, w]`, all flattened row-major. Taps that fall outside the image are skipped, which
/// is equivalent to padding with zeros.
///
/// # Panics
/// Panics on out-of-bounds indexing if `input` or `weight` is shorter than those shapes imply.
fn conv2d_3x3_pad1(
    input: &[f32],
    in_c: usize,
    out_c: usize,
    h: usize,
    w: usize,
    weight: &[f32],
) -> Vec<f32> {
    let plane = h * w;
    let mut out = vec![0f32; out_c * plane];

    for oc in 0..out_c {
        for y in 0..h {
            for x in 0..w {
                let mut acc = 0f32;
                for ic in 0..in_c {
                    let in_base = ic * plane;
                    let kernel_base = (oc * in_c + ic) * 9;
                    for ky in 0..3usize {
                        // Source row is y + ky - 1; skip taps above or below the image.
                        let iy = match (y + ky).checked_sub(1) {
                            Some(row) if row < h => row,
                            _ => continue,
                        };
                        for kx in 0..3usize {
                            // Source column is x + kx - 1; skip taps left or right of the image.
                            let ix = match (x + kx).checked_sub(1) {
                                Some(col) if col < w => col,
                                _ => continue,
                            };
                            acc += input[in_base + iy * w + ix] * weight[kernel_base + ky * 3 + kx];
                        }
                    }
                }
                out[oc * plane + y * w + x] = acc;
            }
        }
    }
    out
}
797
798fn load_p(
801 sb: &mut SamBuilder,
802 weights: &mut WeightMap,
803 key: &str,
804 transpose: bool,
805) -> Result<HirNodeId> {
806 let (data, shape) = if transpose {
807 weights
808 .take_transposed(key)
809 .map_err(|e| anyhow!("transpose-load `{key}`: {e}"))?
810 } else {
811 weights
812 .take(key)
813 .map_err(|e| anyhow!("load `{key}`: {e}"))?
814 };
815 let name = key.to_string();
816 let id = sb.m().param(&name, Shape::new(&shape, DType::F32));
817 sb.params.insert(name, data);
818 Ok(id)
819}
820
821#[allow(dead_code)]
822fn scalar_param(sb: &mut SamBuilder, name: &str, value: f32) -> HirNodeId {
823 let id = sb.m().param(name, Shape::new(&[1], DType::F32));
824 sb.params.insert(name.to_string(), vec![value]);
825 id
826}
827
828fn const_param(sb: &mut SamBuilder, name: &str, shape: &[usize], data: Vec<f32>) -> HirNodeId {
829 let id = sb.m().param(name, Shape::new(shape, DType::F32));
830 sb.params.insert(name.to_string(), data);
831 id
832}
833
834fn pad_zero_param(sb: &mut SamBuilder, name: &str, shape: &[usize]) -> HirNodeId {
835 let n: usize = shape.iter().product();
836 let id = sb.m().param(name, Shape::new(shape, DType::F32));
837 sb.params.insert(name.to_string(), vec![0f32; n]);
838 id
839}