egui_treeize/ui/
state.rs

1use egui::{
2  Context, Id, Pos2, Rect, Ui, Vec2,
3  ahash::HashSet,
4  emath::{GuiRounding, TSTransform},
5  style::Spacing,
6};
7use smallvec::{SmallVec, ToSmallVec, smallvec};
8
9use crate::{InPinId, NodeId, OutPinId, Snarl};
10
11use super::{SnarlWidget, transform_matching_points};
12
/// Heights of a node's pin rows, as stored in `NodeState::input_heights` and
/// `NodeState::output_heights`.
///
/// Up to 8 rows are stored inline before the `SmallVec` spills to the heap.
pub type RowHeights = SmallVec<[f32; 8]>;
14
/// Node UI state.
///
/// Loaded from egui temporary memory with `NodeState::load` and written back
/// with `NodeState::store`, which only persists it when a setter changed a value.
#[derive(Debug)]
pub struct NodeState {
  /// Node size for this frame.
  /// It is updated to fit content.
  size: Vec2,
  /// Height of the node header.
  header_height: f32,
  /// Heights of the input pin rows.
  input_heights: RowHeights,
  /// Heights of the output pin rows.
  output_heights: RowHeights,

  /// Key under which the persisted `NodeData` lives in egui memory.
  id: Id,
  /// Set when a field changed since loading; `store` only writes when set.
  dirty: bool,
}
28
/// Persisted part of `NodeState`, kept in egui temporary memory between frames.
#[derive(Clone, PartialEq)]
struct NodeData {
  /// Measured node size.
  size: Vec2,
  /// Measured header height.
  header_height: f32,
  /// Measured heights of the input pin rows.
  input_heights: RowHeights,
  /// Measured heights of the output pin rows.
  output_heights: RowHeights,
}
36
37impl NodeState {
38  pub fn load(cx: &Context, id: Id, spacing: &Spacing) -> Self {
39    cx.data(|d| d.get_temp::<NodeData>(id)).map_or_else(
40      || {
41        cx.request_discard("NodeState initialization");
42        Self::initial(id, spacing)
43      },
44      |data| NodeState {
45        size: data.size,
46        header_height: data.header_height,
47        input_heights: data.input_heights,
48        output_heights: data.output_heights,
49        id,
50        dirty: false,
51      },
52    )
53  }
54
55  pub fn clear(self, cx: &Context) {
56    cx.data_mut(|d| d.remove::<Self>(self.id));
57  }
58
59  pub fn store(self, cx: &Context) {
60    if self.dirty {
61      cx.data_mut(|d| {
62        d.insert_temp(
63          self.id,
64          NodeData {
65            size: self.size,
66            header_height: self.header_height,
67            input_heights: self.input_heights,
68            output_heights: self.output_heights,
69          },
70        );
71      });
72      cx.request_repaint();
73    }
74  }
75
76  /// Finds node rect at specific position (excluding node frame margin).
77  pub fn node_rect(&self, pos: Pos2, openness: f32) -> Rect {
78    Rect::from_min_size(
79      pos,
80      egui::vec2(self.size.x, f32::max(self.header_height, self.size.y * openness)),
81    )
82    .round_ui()
83  }
84
85  pub fn payload_offset(&self, openness: f32) -> f32 {
86    ((self.size.y) * (1.0 - openness)).round_ui()
87  }
88
89  pub fn set_size(&mut self, size: Vec2) {
90    if self.size != size {
91      self.size = size;
92      self.dirty = true;
93    }
94  }
95
96  pub fn header_height(&self) -> f32 {
97    self.header_height.round_ui()
98  }
99
100  pub fn set_header_height(&mut self, height: f32) {
101    #[allow(clippy::float_cmp)]
102    if self.header_height != height {
103      self.header_height = height;
104      self.dirty = true;
105    }
106  }
107
108  pub const fn input_heights(&self) -> &RowHeights {
109    &self.input_heights
110  }
111
112  pub const fn output_heights(&self) -> &RowHeights {
113    &self.output_heights
114  }
115
116  pub fn set_input_heights(&mut self, input_heights: RowHeights) {
117    #[allow(clippy::float_cmp)]
118    if self.input_heights != input_heights {
119      self.input_heights = input_heights;
120      self.dirty = true;
121    }
122  }
123
124  pub fn set_output_heights(&mut self, output_heights: RowHeights) {
125    #[allow(clippy::float_cmp)]
126    if self.output_heights != output_heights {
127      self.output_heights = output_heights;
128      self.dirty = true;
129    }
130  }
131
132  const fn initial(id: Id, spacing: &Spacing) -> Self {
133    NodeState {
134      size: spacing.interact_size,
135      header_height: spacing.interact_size.y,
136      input_heights: SmallVec::new_const(),
137      output_heights: SmallVec::new_const(),
138      id,
139      dirty: true,
140    }
141  }
142}
143
/// Pins from which new, not yet connected wires are being created.
///
/// All pins share one direction: either all inputs or all outputs.
#[derive(Clone)]
pub enum NewWires {
  /// New wires starting from input pins.
  In(SmallVec<[InPinId; 4]>),
  /// New wires starting from output pins.
  Out(SmallVec<[OutPinId; 4]>),
}
149
/// In-progress rectangle selection spanning from `origin` to `current`.
#[derive(Clone, Copy)]
struct RectSelect {
  /// Position where the selection was started.
  origin: Pos2,
  /// Latest position the selection was updated to.
  current: Pos2,
}
155
/// UI state of a snarl graph view, persisted in egui temporary memory between frames.
pub struct SnarlState {
  /// Snarl viewport transform to global space.
  to_global: TSTransform,

  /// Pins of new wires currently being created, if any.
  new_wires: Option<NewWires>,

  /// Flag indicating that new wires are owned by the menu now.
  new_wires_menu: bool,

  /// Key under which this state is stored in egui memory.
  id: Id,

  /// Flag indicating that the graph state is dirty and must be saved.
  dirty: bool,

  /// Active rect selection.
  rect_selection: Option<RectSelect>,

  /// Order of nodes to draw; `node_to_top` moves a node to the end of this list.
  draw_order: Vec<NodeId>,

  /// List of currently selected nodes, most recently selected last.
  selected_nodes: SmallVec<[NodeId; 8]>,
}
179
180#[derive(Clone, Default)]
181struct DrawOrder(Vec<NodeId>);
182
183impl DrawOrder {
184  fn save(self, cx: &Context, id: Id) {
185    cx.data_mut(|d| {
186      if self.0.is_empty() {
187        d.remove_temp::<Self>(id);
188      } else {
189        d.insert_temp::<Self>(id, self);
190      }
191    });
192  }
193
194  fn load(cx: &Context, id: Id) -> Self {
195    cx.data(|d| d.get_temp::<Self>(id)).unwrap_or_default()
196  }
197}
198
199#[derive(Clone, Default)]
200struct SelectedNodes(SmallVec<[NodeId; 8]>);
201
202impl SelectedNodes {
203  fn save(self, cx: &Context, id: Id) {
204    cx.data_mut(|d| {
205      if self.0.is_empty() {
206        d.remove_temp::<Self>(id);
207      } else {
208        d.get_temp_mut_or_default::<Self>(id).clone_from(&self);
209        d.insert_temp::<Self>(id, self);
210      }
211    });
212  }
213
214  fn load(cx: &Context, id: Id) -> Self {
215    cx.data(|d| d.get_temp::<Self>(id)).unwrap_or_default()
216  }
217}
218
219#[derive(Clone)]
220struct SnarlStateData {
221  to_global: TSTransform,
222  new_wires: Option<NewWires>,
223  new_wires_menu: bool,
224  rect_selection: Option<RectSelect>,
225}
226
227impl SnarlStateData {
228  fn save(self, cx: &Context, id: Id) {
229    cx.data_mut(|d| {
230      d.insert_temp(id, self);
231    });
232  }
233
234  fn load(cx: &Context, id: Id) -> Option<Self> {
235    cx.data(|d| d.get_temp(id))
236  }
237}
238
239fn prune_selected_nodes<T>(selected_nodes: &mut SmallVec<[NodeId; 8]>, snarl: &Snarl<T>) -> bool {
240  let old_size = selected_nodes.len();
241  selected_nodes.retain(|node| snarl.nodes.contains(node.0));
242  old_size != selected_nodes.len()
243}
244
245impl SnarlState {
246  pub fn load<T>(
247    cx: &Context,
248    id: Id,
249    snarl: &Snarl<T>,
250    ui_rect: Rect,
251    min_scale: f32,
252    max_scale: f32,
253  ) -> Self {
254    let Some(data) = SnarlStateData::load(cx, id) else {
255      cx.request_discard("Initial placing");
256      return Self::initial(id, snarl, ui_rect, min_scale, max_scale);
257    };
258
259    let mut selected_nodes = SelectedNodes::load(cx, id).0;
260    let dirty = prune_selected_nodes(&mut selected_nodes, snarl);
261
262    let draw_order = DrawOrder::load(cx, id).0;
263
264    SnarlState {
265      to_global: data.to_global,
266      new_wires: data.new_wires,
267      new_wires_menu: data.new_wires_menu,
268      id,
269      dirty,
270      rect_selection: data.rect_selection,
271      draw_order,
272      selected_nodes,
273    }
274  }
275
276  fn initial<T>(id: Id, snarl: &Snarl<T>, ui_rect: Rect, min_scale: f32, max_scale: f32) -> Self {
277    let mut bb = Rect::NOTHING;
278
279    for (_, node) in &snarl.nodes {
280      bb.extend_with(node.pos);
281    }
282
283    if bb.is_finite() {
284      bb = bb.expand(100.0);
285    } else if ui_rect.is_finite() {
286      bb = ui_rect;
287    } else {
288      bb = Rect::from_min_max(Pos2::new(-100.0, -100.0), Pos2::new(100.0, 100.0));
289    }
290
291    let scaling2 = ui_rect.size() / bb.size();
292    let scaling = scaling2.min_elem().clamp(min_scale, max_scale);
293
294    let to_global = transform_matching_points(bb.center(), ui_rect.center(), scaling);
295
296    SnarlState {
297      to_global,
298      new_wires: None,
299      new_wires_menu: false,
300      id,
301      dirty: true,
302      draw_order: Vec::new(),
303      rect_selection: None,
304      selected_nodes: SmallVec::new(),
305    }
306  }
307
308  #[inline(always)]
309  pub fn store<T>(mut self, snarl: &Snarl<T>, cx: &Context) {
310    self.dirty |= prune_selected_nodes(&mut self.selected_nodes, snarl);
311
312    if self.dirty {
313      let data = SnarlStateData {
314        to_global: self.to_global,
315        new_wires: self.new_wires,
316        new_wires_menu: self.new_wires_menu,
317        rect_selection: self.rect_selection,
318      };
319      data.save(cx, self.id);
320
321      DrawOrder(self.draw_order).save(cx, self.id);
322      SelectedNodes(self.selected_nodes).save(cx, self.id);
323
324      cx.request_repaint();
325    }
326  }
327
328  pub const fn to_global(&self) -> TSTransform {
329    self.to_global
330  }
331
332  pub fn set_to_global(&mut self, to_global: TSTransform) {
333    if self.to_global != to_global {
334      self.to_global = to_global;
335      self.dirty = true;
336    }
337  }
338
339  pub fn look_at(&mut self, view: Rect, ui_rect: Rect, min_scale: f32, max_scale: f32) {
340    let scaling2 = ui_rect.size() / view.size();
341    let scaling = scaling2.min_elem().clamp(min_scale, max_scale);
342
343    let to_global = transform_matching_points(view.center(), ui_rect.center(), scaling);
344
345    if self.to_global != to_global {
346      self.to_global = to_global;
347      self.dirty = true;
348    }
349  }
350
351  pub fn start_new_wire_in(&mut self, pin: InPinId) {
352    self.new_wires = Some(NewWires::In(smallvec![pin]));
353    self.new_wires_menu = false;
354    self.dirty = true;
355  }
356
357  pub fn start_new_wire_out(&mut self, pin: OutPinId) {
358    self.new_wires = Some(NewWires::Out(smallvec![pin]));
359    self.new_wires_menu = false;
360    self.dirty = true;
361  }
362
363  pub fn start_new_wires_in(&mut self, pins: &[InPinId]) {
364    self.new_wires = Some(NewWires::In(pins.to_smallvec()));
365    self.new_wires_menu = false;
366    self.dirty = true;
367  }
368
369  pub fn start_new_wires_out(&mut self, pins: &[OutPinId]) {
370    self.new_wires = Some(NewWires::Out(pins.to_smallvec()));
371    self.new_wires_menu = false;
372    self.dirty = true;
373  }
374
375  pub fn add_new_wire_in(&mut self, pin: InPinId) {
376    debug_assert!(!self.new_wires_menu);
377    let Some(NewWires::In(pins)) = &mut self.new_wires else {
378      unreachable!();
379    };
380
381    if !pins.contains(&pin) {
382      pins.push(pin);
383      self.dirty = true;
384    }
385  }
386
387  pub fn add_new_wire_out(&mut self, pin: OutPinId) {
388    debug_assert!(!self.new_wires_menu);
389    let Some(NewWires::Out(pins)) = &mut self.new_wires else {
390      unreachable!();
391    };
392
393    if !pins.contains(&pin) {
394      pins.push(pin);
395      self.dirty = true;
396    }
397  }
398
399  pub fn remove_new_wire_in(&mut self, pin: InPinId) {
400    debug_assert!(!self.new_wires_menu);
401    let Some(NewWires::In(pins)) = &mut self.new_wires else {
402      unreachable!();
403    };
404
405    if let Some(idx) = pins.iter().position(|p| *p == pin) {
406      pins.swap_remove(idx);
407      self.dirty = true;
408    }
409  }
410
411  pub fn remove_new_wire_out(&mut self, pin: OutPinId) {
412    debug_assert!(!self.new_wires_menu);
413    let Some(NewWires::Out(pins)) = &mut self.new_wires else {
414      unreachable!();
415    };
416
417    if let Some(idx) = pins.iter().position(|p| *p == pin) {
418      pins.swap_remove(idx);
419      self.dirty = true;
420    }
421  }
422
423  pub const fn has_new_wires(&self) -> bool {
424    matches!((self.new_wires.as_ref(), self.new_wires_menu), (Some(_), false))
425  }
426
427  pub const fn has_new_wires_in(&self) -> bool {
428    matches!((&self.new_wires, self.new_wires_menu), (Some(NewWires::In(_)), false))
429  }
430
431  pub const fn has_new_wires_out(&self) -> bool {
432    matches!((&self.new_wires, self.new_wires_menu), (Some(NewWires::Out(_)), false))
433  }
434
435  pub const fn new_wires(&self) -> Option<&NewWires> {
436    match (&self.new_wires, self.new_wires_menu) {
437      (Some(new_wires), false) => Some(new_wires),
438      _ => None,
439    }
440  }
441
442  pub const fn take_new_wires(&mut self) -> Option<NewWires> {
443    match (&self.new_wires, self.new_wires_menu) {
444      (Some(_), false) => {
445        self.dirty = true;
446        self.new_wires.take()
447      }
448      _ => None,
449    }
450  }
451
452  pub(crate) const fn take_new_wires_menu(&mut self) -> Option<NewWires> {
453    match (&self.new_wires, self.new_wires_menu) {
454      (Some(_), true) => {
455        self.dirty = true;
456        self.new_wires.take()
457      }
458      _ => None,
459    }
460  }
461
462  pub(crate) fn set_new_wires_menu(&mut self, wires: NewWires) {
463    debug_assert!(self.new_wires.is_none());
464    self.new_wires = Some(wires);
465    self.new_wires_menu = true;
466  }
467
468  pub(crate) fn update_draw_order<T>(&mut self, snarl: &Snarl<T>) -> Vec<NodeId> {
469    let mut node_ids = snarl.nodes.iter().map(|(id, _)| NodeId(id)).collect::<HashSet<_>>();
470
471    self.draw_order.retain(|id| {
472      let has = node_ids.remove(id);
473      self.dirty |= !has;
474      has
475    });
476
477    self.dirty |= !node_ids.is_empty();
478
479    for new_id in node_ids {
480      self.draw_order.push(new_id);
481    }
482
483    self.draw_order.clone()
484  }
485
486  pub(crate) fn node_to_top(&mut self, node: NodeId) {
487    if let Some(order) = self.draw_order.iter().position(|idx| *idx == node) {
488      self.draw_order.remove(order);
489      self.draw_order.push(node);
490    }
491    self.dirty = true;
492  }
493
494  pub fn selected_nodes(&self) -> &[NodeId] {
495    &self.selected_nodes
496  }
497
498  pub fn select_one_node(&mut self, reset: bool, node: NodeId) {
499    if reset {
500      if self.selected_nodes[..] == [node] {
501        return;
502      }
503
504      self.deselect_all_nodes();
505    } else if let Some(pos) = self.selected_nodes.iter().position(|n| *n == node) {
506      if pos == self.selected_nodes.len() - 1 {
507        return;
508      }
509      self.selected_nodes.remove(pos);
510    }
511    self.selected_nodes.push(node);
512    self.dirty = true;
513  }
514
515  pub fn select_many_nodes(&mut self, reset: bool, nodes: impl Iterator<Item = NodeId>) {
516    if reset {
517      self.deselect_all_nodes();
518      self.selected_nodes.extend(nodes);
519      self.dirty = true;
520    } else {
521      nodes.for_each(|node| self.select_one_node(false, node));
522    }
523  }
524
525  pub fn deselect_one_node(&mut self, node: NodeId) {
526    if let Some(pos) = self.selected_nodes.iter().position(|n| *n == node) {
527      self.selected_nodes.remove(pos);
528      self.dirty = true;
529    }
530  }
531
532  pub fn deselect_many_nodes(&mut self, nodes: impl Iterator<Item = NodeId>) {
533    for node in nodes {
534      if let Some(pos) = self.selected_nodes.iter().position(|n| *n == node) {
535        self.selected_nodes.remove(pos);
536        self.dirty = true;
537      }
538    }
539  }
540
541  pub fn deselect_all_nodes(&mut self) {
542    self.dirty |= !self.selected_nodes.is_empty();
543    self.selected_nodes.clear();
544  }
545
546  pub const fn start_rect_selection(&mut self, pos: Pos2) {
547    self.dirty |= self.rect_selection.is_none();
548    self.rect_selection = Some(RectSelect { origin: pos, current: pos });
549  }
550
551  pub const fn stop_rect_selection(&mut self) {
552    self.dirty |= self.rect_selection.is_some();
553    self.rect_selection = None;
554  }
555
556  pub const fn is_rect_selection(&self) -> bool {
557    self.rect_selection.is_some()
558  }
559
560  pub const fn update_rect_selection(&mut self, pos: Pos2) {
561    if let Some(rect_selection) = &mut self.rect_selection {
562      rect_selection.current = pos;
563      self.dirty = true;
564    }
565  }
566
567  pub fn rect_selection(&self) -> Option<Rect> {
568    let rect = self.rect_selection?;
569    Some(Rect::from_two_pos(rect.origin, rect.current))
570  }
571}
572
573impl SnarlWidget {
574  /// Returns list of nodes selected in the UI for the `SnarlWidget` with same id.
575  ///
576  /// Use same `Ui` instance that was used in [`SnarlWidget::show`].
577  #[must_use]
578  #[inline]
579  pub fn get_selected_nodes(self, ui: &Ui) -> Vec<NodeId> {
580    self.get_selected_nodes_at(ui.id(), ui.ctx())
581  }
582
583  /// Returns list of nodes selected in the UI for the `SnarlWidget` with same id.
584  ///
585  /// `ui_id` must be the Id of the `Ui` instance that was used in [`SnarlWidget::show`].
586  #[must_use]
587  #[inline]
588  pub fn get_selected_nodes_at(self, ui_id: Id, ctx: &Context) -> Vec<NodeId> {
589    let snarl_id = self.get_id(ui_id);
590
591    ctx.data(|d| d.get_temp::<SelectedNodes>(snarl_id).unwrap_or_default().0).into_vec()
592  }
593}
594
595/// Returns nodes selected in the UI for the `SnarlWidget` with same ID.
596///
597/// Only works if [`SnarlWidget::id`] was used.
598/// For other cases construct [`SnarlWidget`] and use [`SnarlWidget::get_selected_nodes`] or [`SnarlWidget::get_selected_nodes_at`].
599#[must_use]
600#[inline]
601pub fn get_selected_nodes(id: Id, ctx: &Context) -> Vec<NodeId> {
602  ctx.data(|d| d.get_temp::<SelectedNodes>(id).unwrap_or_default().0).into_vec()
603}