1use egui::{
2 Context, Id, Pos2, Rect, Ui, Vec2,
3 ahash::HashSet,
4 emath::{GuiRounding, TSTransform},
5 style::Spacing,
6};
7use smallvec::{SmallVec, ToSmallVec, smallvec};
8
9use crate::{InPinId, NodeId, OutPinId, Snarl};
10
11use super::{SnarlWidget, transform_matching_points};
12
/// Heights of a node's input or output pin rows (presumably one entry per pin, in pin order).
pub type RowHeights = SmallVec<[f32; 8]>;
14
/// Cached layout measurements of a single node, kept in egui temporary data between frames.
///
/// Obtained with `NodeState::load`, updated through its setters and written back with
/// `NodeState::store`, which only touches the context when something changed.
#[derive(Debug)]
pub struct NodeState {
    /// Size of the fully opened node; `node_rect` scales the height by openness.
    size: Vec2,
    /// Height of the node header; `node_rect` never returns a rect shorter than this.
    header_height: f32,
    /// Heights of the input pin rows.
    input_heights: RowHeights,
    /// Heights of the output pin rows.
    output_heights: RowHeights,

    /// Key under which the measurements are stored in egui temporary data.
    id: Id,
    /// Set when something changed since loading (or the state is new), so `store` must save it.
    dirty: bool,
}
28
/// Snapshot of `NodeState`'s measurements as stored in egui temporary data.
#[derive(Clone, PartialEq)]
struct NodeData {
    size: Vec2,
    header_height: f32,
    input_heights: RowHeights,
    output_heights: RowHeights,
}
36
37impl NodeState {
38 pub fn load(cx: &Context, id: Id, spacing: &Spacing) -> Self {
39 cx.data(|d| d.get_temp::<NodeData>(id)).map_or_else(
40 || {
41 cx.request_discard("NodeState initialization");
42 Self::initial(id, spacing)
43 },
44 |data| NodeState {
45 size: data.size,
46 header_height: data.header_height,
47 input_heights: data.input_heights,
48 output_heights: data.output_heights,
49 id,
50 dirty: false,
51 },
52 )
53 }
54
55 pub fn clear(self, cx: &Context) {
56 cx.data_mut(|d| d.remove::<Self>(self.id));
57 }
58
59 pub fn store(self, cx: &Context) {
60 if self.dirty {
61 cx.data_mut(|d| {
62 d.insert_temp(
63 self.id,
64 NodeData {
65 size: self.size,
66 header_height: self.header_height,
67 input_heights: self.input_heights,
68 output_heights: self.output_heights,
69 },
70 );
71 });
72 cx.request_repaint();
73 }
74 }
75
76 pub fn node_rect(&self, pos: Pos2, openness: f32) -> Rect {
78 Rect::from_min_size(
79 pos,
80 egui::vec2(self.size.x, f32::max(self.header_height, self.size.y * openness)),
81 )
82 .round_ui()
83 }
84
85 pub fn payload_offset(&self, openness: f32) -> f32 {
86 ((self.size.y) * (1.0 - openness)).round_ui()
87 }
88
89 pub fn set_size(&mut self, size: Vec2) {
90 if self.size != size {
91 self.size = size;
92 self.dirty = true;
93 }
94 }
95
96 pub fn header_height(&self) -> f32 {
97 self.header_height.round_ui()
98 }
99
100 pub fn set_header_height(&mut self, height: f32) {
101 #[allow(clippy::float_cmp)]
102 if self.header_height != height {
103 self.header_height = height;
104 self.dirty = true;
105 }
106 }
107
108 pub const fn input_heights(&self) -> &RowHeights {
109 &self.input_heights
110 }
111
112 pub const fn output_heights(&self) -> &RowHeights {
113 &self.output_heights
114 }
115
116 pub fn set_input_heights(&mut self, input_heights: RowHeights) {
117 #[allow(clippy::float_cmp)]
118 if self.input_heights != input_heights {
119 self.input_heights = input_heights;
120 self.dirty = true;
121 }
122 }
123
124 pub fn set_output_heights(&mut self, output_heights: RowHeights) {
125 #[allow(clippy::float_cmp)]
126 if self.output_heights != output_heights {
127 self.output_heights = output_heights;
128 self.dirty = true;
129 }
130 }
131
132 const fn initial(id: Id, spacing: &Spacing) -> Self {
133 NodeState {
134 size: spacing.interact_size,
135 header_height: spacing.interact_size.y,
136 input_heights: SmallVec::new_const(),
137 output_heights: SmallVec::new_const(),
138 id,
139 dirty: true,
140 }
141 }
142}
143
/// Pins from which new wires are being created. All pins share one direction.
#[derive(Clone)]
pub enum NewWires {
    /// Wires started from input pins.
    In(SmallVec<[InPinId; 4]>),
    /// Wires started from output pins.
    Out(SmallVec<[OutPinId; 4]>),
}
149
/// Rectangle selection in progress, described by two opposite corners.
#[derive(Clone, Copy)]
struct RectSelect {
    /// Corner where the selection was started.
    origin: Pos2,
    /// Corner most recently set by `update_rect_selection`.
    current: Pos2,
}
155
/// Per-widget editor state: view transform, wires being created, selection and draw order.
///
/// Loaded from egui temporary data with `SnarlState::load` and written back with
/// `SnarlState::store` when `dirty` is set.
pub struct SnarlState {
    /// Transform from graph coordinates (e.g. node positions) to UI coordinates.
    to_global: TSTransform,

    /// Pins from which new wires are currently being created, if any.
    new_wires: Option<NewWires>,

    /// When `true`, `new_wires` is parked for the new-wires menu and hidden from the
    /// regular new-wire accessors.
    new_wires_menu: bool,

    /// Key under which the state is stored in egui temporary data.
    id: Id,

    /// Set when anything changed since loading, so `store` must write the state back.
    dirty: bool,

    /// Active rectangle selection, if any.
    rect_selection: Option<RectSelect>,

    /// Node drawing order. `node_to_top` moves a node to the end, so later entries are on top.
    draw_order: Vec<NodeId>,

    /// Currently selected nodes; `select_one_node` appends, so the newest selection is last.
    selected_nodes: SmallVec<[NodeId; 8]>,
}
179
/// Persisted form of `SnarlState::draw_order`, stored separately from `SnarlStateData`.
#[derive(Clone, Default)]
struct DrawOrder(Vec<NodeId>);
182
183impl DrawOrder {
184 fn save(self, cx: &Context, id: Id) {
185 cx.data_mut(|d| {
186 if self.0.is_empty() {
187 d.remove_temp::<Self>(id);
188 } else {
189 d.insert_temp::<Self>(id, self);
190 }
191 });
192 }
193
194 fn load(cx: &Context, id: Id) -> Self {
195 cx.data(|d| d.get_temp::<Self>(id)).unwrap_or_default()
196 }
197}
198
/// Persisted form of `SnarlState::selected_nodes`. It is also read by the public
/// `get_selected_nodes` helpers.
#[derive(Clone, Default)]
struct SelectedNodes(SmallVec<[NodeId; 8]>);
201
202impl SelectedNodes {
203 fn save(self, cx: &Context, id: Id) {
204 cx.data_mut(|d| {
205 if self.0.is_empty() {
206 d.remove_temp::<Self>(id);
207 } else {
208 d.get_temp_mut_or_default::<Self>(id).clone_from(&self);
209 d.insert_temp::<Self>(id, self);
210 }
211 });
212 }
213
214 fn load(cx: &Context, id: Id) -> Self {
215 cx.data(|d| d.get_temp::<Self>(id)).unwrap_or_default()
216 }
217}
218
/// Persisted snapshot of the non-collection parts of `SnarlState`. The draw order and the
/// selection are stored separately as `DrawOrder` and `SelectedNodes`.
#[derive(Clone)]
struct SnarlStateData {
    to_global: TSTransform,
    new_wires: Option<NewWires>,
    new_wires_menu: bool,
    rect_selection: Option<RectSelect>,
}
226
227impl SnarlStateData {
228 fn save(self, cx: &Context, id: Id) {
229 cx.data_mut(|d| {
230 d.insert_temp(id, self);
231 });
232 }
233
234 fn load(cx: &Context, id: Id) -> Option<Self> {
235 cx.data(|d| d.get_temp(id))
236 }
237}
238
239fn prune_selected_nodes<T>(selected_nodes: &mut SmallVec<[NodeId; 8]>, snarl: &Snarl<T>) -> bool {
240 let old_size = selected_nodes.len();
241 selected_nodes.retain(|node| snarl.nodes.contains(node.0));
242 old_size != selected_nodes.len()
243}
244
245impl SnarlState {
246 pub fn load<T>(
247 cx: &Context,
248 id: Id,
249 snarl: &Snarl<T>,
250 ui_rect: Rect,
251 min_scale: f32,
252 max_scale: f32,
253 ) -> Self {
254 let Some(data) = SnarlStateData::load(cx, id) else {
255 cx.request_discard("Initial placing");
256 return Self::initial(id, snarl, ui_rect, min_scale, max_scale);
257 };
258
259 let mut selected_nodes = SelectedNodes::load(cx, id).0;
260 let dirty = prune_selected_nodes(&mut selected_nodes, snarl);
261
262 let draw_order = DrawOrder::load(cx, id).0;
263
264 SnarlState {
265 to_global: data.to_global,
266 new_wires: data.new_wires,
267 new_wires_menu: data.new_wires_menu,
268 id,
269 dirty,
270 rect_selection: data.rect_selection,
271 draw_order,
272 selected_nodes,
273 }
274 }
275
276 fn initial<T>(id: Id, snarl: &Snarl<T>, ui_rect: Rect, min_scale: f32, max_scale: f32) -> Self {
277 let mut bb = Rect::NOTHING;
278
279 for (_, node) in &snarl.nodes {
280 bb.extend_with(node.pos);
281 }
282
283 if bb.is_finite() {
284 bb = bb.expand(100.0);
285 } else if ui_rect.is_finite() {
286 bb = ui_rect;
287 } else {
288 bb = Rect::from_min_max(Pos2::new(-100.0, -100.0), Pos2::new(100.0, 100.0));
289 }
290
291 let scaling2 = ui_rect.size() / bb.size();
292 let scaling = scaling2.min_elem().clamp(min_scale, max_scale);
293
294 let to_global = transform_matching_points(bb.center(), ui_rect.center(), scaling);
295
296 SnarlState {
297 to_global,
298 new_wires: None,
299 new_wires_menu: false,
300 id,
301 dirty: true,
302 draw_order: Vec::new(),
303 rect_selection: None,
304 selected_nodes: SmallVec::new(),
305 }
306 }
307
308 #[inline(always)]
309 pub fn store<T>(mut self, snarl: &Snarl<T>, cx: &Context) {
310 self.dirty |= prune_selected_nodes(&mut self.selected_nodes, snarl);
311
312 if self.dirty {
313 let data = SnarlStateData {
314 to_global: self.to_global,
315 new_wires: self.new_wires,
316 new_wires_menu: self.new_wires_menu,
317 rect_selection: self.rect_selection,
318 };
319 data.save(cx, self.id);
320
321 DrawOrder(self.draw_order).save(cx, self.id);
322 SelectedNodes(self.selected_nodes).save(cx, self.id);
323
324 cx.request_repaint();
325 }
326 }
327
328 pub const fn to_global(&self) -> TSTransform {
329 self.to_global
330 }
331
332 pub fn set_to_global(&mut self, to_global: TSTransform) {
333 if self.to_global != to_global {
334 self.to_global = to_global;
335 self.dirty = true;
336 }
337 }
338
339 pub fn look_at(&mut self, view: Rect, ui_rect: Rect, min_scale: f32, max_scale: f32) {
340 let scaling2 = ui_rect.size() / view.size();
341 let scaling = scaling2.min_elem().clamp(min_scale, max_scale);
342
343 let to_global = transform_matching_points(view.center(), ui_rect.center(), scaling);
344
345 if self.to_global != to_global {
346 self.to_global = to_global;
347 self.dirty = true;
348 }
349 }
350
351 pub fn start_new_wire_in(&mut self, pin: InPinId) {
352 self.new_wires = Some(NewWires::In(smallvec![pin]));
353 self.new_wires_menu = false;
354 self.dirty = true;
355 }
356
357 pub fn start_new_wire_out(&mut self, pin: OutPinId) {
358 self.new_wires = Some(NewWires::Out(smallvec![pin]));
359 self.new_wires_menu = false;
360 self.dirty = true;
361 }
362
363 pub fn start_new_wires_in(&mut self, pins: &[InPinId]) {
364 self.new_wires = Some(NewWires::In(pins.to_smallvec()));
365 self.new_wires_menu = false;
366 self.dirty = true;
367 }
368
369 pub fn start_new_wires_out(&mut self, pins: &[OutPinId]) {
370 self.new_wires = Some(NewWires::Out(pins.to_smallvec()));
371 self.new_wires_menu = false;
372 self.dirty = true;
373 }
374
375 pub fn add_new_wire_in(&mut self, pin: InPinId) {
376 debug_assert!(!self.new_wires_menu);
377 let Some(NewWires::In(pins)) = &mut self.new_wires else {
378 unreachable!();
379 };
380
381 if !pins.contains(&pin) {
382 pins.push(pin);
383 self.dirty = true;
384 }
385 }
386
387 pub fn add_new_wire_out(&mut self, pin: OutPinId) {
388 debug_assert!(!self.new_wires_menu);
389 let Some(NewWires::Out(pins)) = &mut self.new_wires else {
390 unreachable!();
391 };
392
393 if !pins.contains(&pin) {
394 pins.push(pin);
395 self.dirty = true;
396 }
397 }
398
399 pub fn remove_new_wire_in(&mut self, pin: InPinId) {
400 debug_assert!(!self.new_wires_menu);
401 let Some(NewWires::In(pins)) = &mut self.new_wires else {
402 unreachable!();
403 };
404
405 if let Some(idx) = pins.iter().position(|p| *p == pin) {
406 pins.swap_remove(idx);
407 self.dirty = true;
408 }
409 }
410
411 pub fn remove_new_wire_out(&mut self, pin: OutPinId) {
412 debug_assert!(!self.new_wires_menu);
413 let Some(NewWires::Out(pins)) = &mut self.new_wires else {
414 unreachable!();
415 };
416
417 if let Some(idx) = pins.iter().position(|p| *p == pin) {
418 pins.swap_remove(idx);
419 self.dirty = true;
420 }
421 }
422
423 pub const fn has_new_wires(&self) -> bool {
424 matches!((self.new_wires.as_ref(), self.new_wires_menu), (Some(_), false))
425 }
426
427 pub const fn has_new_wires_in(&self) -> bool {
428 matches!((&self.new_wires, self.new_wires_menu), (Some(NewWires::In(_)), false))
429 }
430
431 pub const fn has_new_wires_out(&self) -> bool {
432 matches!((&self.new_wires, self.new_wires_menu), (Some(NewWires::Out(_)), false))
433 }
434
435 pub const fn new_wires(&self) -> Option<&NewWires> {
436 match (&self.new_wires, self.new_wires_menu) {
437 (Some(new_wires), false) => Some(new_wires),
438 _ => None,
439 }
440 }
441
442 pub const fn take_new_wires(&mut self) -> Option<NewWires> {
443 match (&self.new_wires, self.new_wires_menu) {
444 (Some(_), false) => {
445 self.dirty = true;
446 self.new_wires.take()
447 }
448 _ => None,
449 }
450 }
451
452 pub(crate) const fn take_new_wires_menu(&mut self) -> Option<NewWires> {
453 match (&self.new_wires, self.new_wires_menu) {
454 (Some(_), true) => {
455 self.dirty = true;
456 self.new_wires.take()
457 }
458 _ => None,
459 }
460 }
461
462 pub(crate) fn set_new_wires_menu(&mut self, wires: NewWires) {
463 debug_assert!(self.new_wires.is_none());
464 self.new_wires = Some(wires);
465 self.new_wires_menu = true;
466 }
467
468 pub(crate) fn update_draw_order<T>(&mut self, snarl: &Snarl<T>) -> Vec<NodeId> {
469 let mut node_ids = snarl.nodes.iter().map(|(id, _)| NodeId(id)).collect::<HashSet<_>>();
470
471 self.draw_order.retain(|id| {
472 let has = node_ids.remove(id);
473 self.dirty |= !has;
474 has
475 });
476
477 self.dirty |= !node_ids.is_empty();
478
479 for new_id in node_ids {
480 self.draw_order.push(new_id);
481 }
482
483 self.draw_order.clone()
484 }
485
486 pub(crate) fn node_to_top(&mut self, node: NodeId) {
487 if let Some(order) = self.draw_order.iter().position(|idx| *idx == node) {
488 self.draw_order.remove(order);
489 self.draw_order.push(node);
490 }
491 self.dirty = true;
492 }
493
494 pub fn selected_nodes(&self) -> &[NodeId] {
495 &self.selected_nodes
496 }
497
498 pub fn select_one_node(&mut self, reset: bool, node: NodeId) {
499 if reset {
500 if self.selected_nodes[..] == [node] {
501 return;
502 }
503
504 self.deselect_all_nodes();
505 } else if let Some(pos) = self.selected_nodes.iter().position(|n| *n == node) {
506 if pos == self.selected_nodes.len() - 1 {
507 return;
508 }
509 self.selected_nodes.remove(pos);
510 }
511 self.selected_nodes.push(node);
512 self.dirty = true;
513 }
514
515 pub fn select_many_nodes(&mut self, reset: bool, nodes: impl Iterator<Item = NodeId>) {
516 if reset {
517 self.deselect_all_nodes();
518 self.selected_nodes.extend(nodes);
519 self.dirty = true;
520 } else {
521 nodes.for_each(|node| self.select_one_node(false, node));
522 }
523 }
524
525 pub fn deselect_one_node(&mut self, node: NodeId) {
526 if let Some(pos) = self.selected_nodes.iter().position(|n| *n == node) {
527 self.selected_nodes.remove(pos);
528 self.dirty = true;
529 }
530 }
531
532 pub fn deselect_many_nodes(&mut self, nodes: impl Iterator<Item = NodeId>) {
533 for node in nodes {
534 if let Some(pos) = self.selected_nodes.iter().position(|n| *n == node) {
535 self.selected_nodes.remove(pos);
536 self.dirty = true;
537 }
538 }
539 }
540
541 pub fn deselect_all_nodes(&mut self) {
542 self.dirty |= !self.selected_nodes.is_empty();
543 self.selected_nodes.clear();
544 }
545
546 pub const fn start_rect_selection(&mut self, pos: Pos2) {
547 self.dirty |= self.rect_selection.is_none();
548 self.rect_selection = Some(RectSelect { origin: pos, current: pos });
549 }
550
551 pub const fn stop_rect_selection(&mut self) {
552 self.dirty |= self.rect_selection.is_some();
553 self.rect_selection = None;
554 }
555
556 pub const fn is_rect_selection(&self) -> bool {
557 self.rect_selection.is_some()
558 }
559
560 pub const fn update_rect_selection(&mut self, pos: Pos2) {
561 if let Some(rect_selection) = &mut self.rect_selection {
562 rect_selection.current = pos;
563 self.dirty = true;
564 }
565 }
566
567 pub fn rect_selection(&self) -> Option<Rect> {
568 let rect = self.rect_selection?;
569 Some(Rect::from_two_pos(rect.origin, rect.current))
570 }
571}
572
573impl SnarlWidget {
574 #[must_use]
578 #[inline]
579 pub fn get_selected_nodes(self, ui: &Ui) -> Vec<NodeId> {
580 self.get_selected_nodes_at(ui.id(), ui.ctx())
581 }
582
583 #[must_use]
587 #[inline]
588 pub fn get_selected_nodes_at(self, ui_id: Id, ctx: &Context) -> Vec<NodeId> {
589 let snarl_id = self.get_id(ui_id);
590
591 ctx.data(|d| d.get_temp::<SelectedNodes>(snarl_id).unwrap_or_default().0).into_vec()
592 }
593}
594
595#[must_use]
600#[inline]
601pub fn get_selected_nodes(id: Id, ctx: &Context) -> Vec<NodeId> {
602 ctx.data(|d| d.get_temp::<SelectedNodes>(id).unwrap_or_default().0).into_vec()
603}