flag_algebra/
operator.rs

//! Computing and storing operators of the flag algebra.
2
3use crate::algebra::QFlag;
4use crate::combinatorics::*;
5use crate::density::*;
6use crate::expr::CoefficientFn;
7use crate::expr::Expr;
8use crate::expr::IndicatorFn;
9use crate::flag::Flag;
10use log::*;
11use ndarray::*;
12use num::*;
13use serde::de::DeserializeOwned;
14use serde::Serialize;
15use sprs::CsMat;
16use std::error::Error;
17use std::fmt;
18use std::fmt::{Display, Formatter};
19use std::fs;
20use std::fs::File;
21use std::io::{BufReader, BufWriter};
22use std::marker::PhantomData;
23use std::ops::*;
24use std::path::*;
25use std::rc::Rc;
26
27/// A trait for flag operators that
28/// can be saved in a file once computed
29/// for the first time.
30///
31/// `A` is the type of the stored object.
32/// The operator operate on flags of type `F`.
33pub trait Savable<A, F>
34where
35    A: Serialize + DeserializeOwned,
36    F: Flag,
37{
38    /// Name of the file where the operator can be saved.
39    fn filename(&self) -> String;
40    /// Compute the object.
41    fn create(&self) -> A;
42    /// Path to the corresponding file.
43    fn file_path(&self) -> PathBuf {
44        let mut filename = PathBuf::from("./data");
45        filename.push(Path::new(F::NAME));
46        filename.push(self.filename());
47        let _ = filename.set_extension("dat");
48        filename
49    }
50    /// (Re)create the object, save it in the corresponding file and return it.
51    fn create_and_save(&self, path: &Path) -> A {
52        info!("Creating {}", path.display());
53        let value = self.create();
54        let file = File::create(path).unwrap();
55        let buf = BufWriter::new(file);
56        let compressed_buf = flate2::write::GzEncoder::new(buf, Default::default());
57        bincode::serialize_into(compressed_buf, &value).unwrap();
58        value
59    }
60    /// Load the object if the file exists and is valid.
61    fn load(&self, path: &Path) -> Result<A, Box<dyn Error>> {
62        let file = File::open(path)?;
63        let buf = BufReader::new(file);
64        let decompressed_buf = flate2::bufread::GzDecoder::new(buf);
65        let data = bincode::deserialize_from(decompressed_buf)?;
66        Ok(data)
67    }
68    /// Function to automatically load the object if the file exists and
69    /// is valid, or create and save it otherwise.
70    fn get(&self) -> A {
71        let path = self.file_path();
72        if path.exists() {
73            debug!("Loading {}", path.display());
74            match self.load(&path) {
75                Ok(v) => {
76                    trace!("Done");
77                    v
78                }
79                Err(e) => {
80                    error!("Failed to load {}: {}", path.display(), e);
81                    self.create_and_save(&path)
82                }
83            }
84        } else {
85            let dir = path.parent().unwrap();
86            match fs::create_dir_all(dir) {
87                Ok(()) => self.create_and_save(&path),
88                Err(e) => {
89                    eprintln!("Cannot create {}.", dir.display());
90                    panic!("{}", e);
91                }
92            }
93        }
94    }
95}
96
/// Type (or root) of a flag.
/// It is identified by its size and its id in the list of flags of that size.
///
/// The derived ordering compares `size` first, then `id`
/// (field order matters for the derived `Ord`).
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Clone)]
pub struct Type<F: Flag> {
    /// Size of the type.
    pub size: usize,
    /// Index of the type in the list of unlabeled flags of this size.
    pub id: usize,
    /// Retains the kind of flag (graph, ...) without storing a value of it.
    phantom: PhantomData<F>,
}
108
109impl<F: Flag> Type<F> {
110    /// Constructor for the type.
111    pub fn new(size: usize, id: usize) -> Self {
112        Self {
113            size,
114            id,
115            phantom: PhantomData,
116        }
117    }
118    /// Create a type of size 0.
119    pub fn empty() -> Self {
120        Self::new(0, 0)
121    }
122    /// Return wether the input has size zero.
123    pub fn is_empty(self) -> bool {
124        self == Self::empty()
125    }
126    /// Write a string that identifies the type.
127    fn to_string_suffix(self) -> String {
128        if self.is_empty() {
129            String::new()
130        } else {
131            format!("_type_{}_id_{}", self.size, self.id)
132        }
133    }
134    /// Create the type corresponding to g
135    pub fn from_flag<G>(g: &G) -> Self
136    where
137        F: Flag,
138        G: Into<F> + Clone,
139    {
140        let f: F = g.clone().into();
141        let size = f.size();
142        let reduced_f = f.canonical();
143        let id = Basis::new(size)
144            .get()
145            .binary_search(&reduced_f)
146            .expect("Flag not found");
147        Self::new(size, id)
148    }
149    /// Iterate on all types of a given size.
150    pub fn types_with_size(size: usize) -> impl Iterator<Item = Self>
151    where
152        F: Flag,
153    {
154        let n_types = Basis::<F>::new(size).get().len();
155        (0..n_types).map(move |id| Self::new(size, id))
156    }
157    /// Print the type identifier in a short way.
158    pub fn print_concise(self) -> String {
159        if self.is_empty() {
160            String::new()
161        } else {
162            format!("{},id{}", self.size, self.id)
163        }
164    }
165}
166
167//============ Basis ===========
168
/// Identifier for the set of flags with given size and type
/// (in the sense of a labeled subgraph).
///
/// The kind of flag is determined by the associated Rust datatype.
/// For instance `Basis<Graph>` is the Rust type for a basis of graphs.
///
/// `basis.get()` returns an ordered vector containing all corresponding flags;
/// it is computed once and then cached on disk through [`Savable`].
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Clone)]
pub struct Basis<F: Flag> {
    /// Number of vertices in the flags of the basis.
    pub size: usize,
    /// Type of the flags of the basis (`Type::empty()` for untyped flags).
    pub t: Type<F>,
}
183
184/// # Defining a Basis
185impl<F: Flag> Basis<F> {
186    /// Constructor for a basis.
187    fn make(size: usize, t: Type<F>) -> Self {
188        assert!(t.size <= size);
189        Self { size, t }
190    }
191    /// Basis of flag with `size` vertices and without type.
192    ///```
193    /// use flag_algebra::*;
194    /// use flag_algebra::flags::Graph;
195    ///
196    /// // Set of graphs of size 3
197    /// // (the kind of flag -Graph- is deduced by type inference)
198    /// let basis = Basis::new(3);
199    /// let size_3_graphs: Vec<Graph> = basis.get();
200    /// assert_eq!(size_3_graphs.len(), 4);
201    ///
202    /// // With explicit type annotation
203    /// let same_basis: Basis<Graph> = Basis::new(3);
204    ///```
205    pub fn new(size: usize) -> Self {
206        Self::make(size, Type::empty())
207    }
208    /// Basis of flag with same size as `self` and type `t`.
209    ///```
210    /// use flag_algebra::*;
211    /// use flag_algebra::flags::Graph;
212    ///
213    /// // Basis of graphs of size 4 rooted on an edge
214    /// let edge = Graph::new(2, &[(0, 1)]);
215    /// let t = Type::from_flag(&edge);
216    /// let basis: Basis<Graph> = Basis::new(4).with_type(t);
217    ///```
218    pub fn with_type(self, t: Type<F>) -> Self {
219        Self::make(self.size, t)
220    }
221    /// Basis of flag with same size as `self` without type `t`.
222    pub fn without_type(self) -> Self {
223        self.with_type(Type::empty())
224    }
225    /// Basis of flag with `size` vertices and same type as `self`.
226    pub fn with_size(self, size: usize) -> Self {
227        Self::make(size, self.t)
228    }
229    /// Print the basis information in a short way.
230    pub fn print_concise(self) -> String {
231        if self.t == Type::empty() {
232            format!("{}", self.size)
233        } else {
234            format!("{},{},id{}", self.size, self.t.size, self.t.id)
235        }
236    }
237}
238
239impl<F: Flag> Display for Type<F> {
240    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
241        if *self == Type::empty() {
242            write!(f, "Empty type")
243        } else {
244            write!(f, "Type of size {} (id {})", self.size, self.id)
245        }
246    }
247}
248
249impl<F: Flag> Display for Basis<F> {
250    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
251        if self.t == Type::empty() {
252            write!(f, "Flags of size {} without type", self.size)
253        } else {
254            write!(f, "Flags of size {} with type {})", self.size, self.t)
255        }
256    }
257}
258
259impl<F: Flag> Mul for Basis<F> {
260    type Output = Self;
261
262    fn mul(self, rhs: Self) -> Self {
263        assert_eq!(self.t, rhs.t);
264        self.with_size(self.size + rhs.size - self.t.size)
265    }
266}
267
268impl<F: Flag> Div for Basis<F> {
269    type Output = Self;
270
271    fn div(self, rhs: Self) -> Self {
272        assert_eq!(self.t, rhs.t);
273        assert!(self.size >= rhs.size);
274        self.with_size(self.size - rhs.size + self.t.size)
275    }
276}
277
278impl<F: Flag> Savable<Vec<F>, F> for Basis<F> {
279    fn filename(&self) -> String {
280        format!("flags_{}{}", self.size, self.t.to_string_suffix())
281    }
282    fn create(&self) -> Vec<F> {
283        if self.t == Type::empty() {
284            if self.size == 0 {
285                F::size_zero_flags()
286            } else {
287                F::generate_next(&self.with_size(self.size - 1).get())
288            }
289        } else {
290            let type_basis = Self::new(self.t.size).get();
291            let type_flag = &type_basis[self.t.id];
292            F::generate_typed_up(type_flag, &self.without_type().get())
293        }
294    }
295}
296
297//================ Split count
/// Operator used for flag multiplication.
///
/// `.get()` returns a vector of matrices `M`
/// where `M[i][j, k]` is the number of ways to split
/// `i` into `j` and `k`.
#[derive(Debug, Clone)]
pub struct SplitCount<F: Flag> {
    /// Number of vertices of the left factor (type vertices included).
    left_size: usize,
    /// Number of vertices of the right factor (type vertices included).
    right_size: usize,
    /// Common type of both factors and of their product.
    type_: Type<F>,
}
308
309impl<F: Flag> SplitCount<F> {
310    pub fn make(left_size: usize, right_size: usize, type_: Type<F>) -> Self {
311        assert!(type_.size <= left_size);
312        assert!(type_.size <= right_size);
313        Self {
314            left_size,
315            right_size,
316            type_,
317        }
318    }
319
320    pub fn from_input(left: &Basis<F>, right: &Basis<F>) -> Self {
321        assert_eq!(left.t, right.t);
322        Self::make(left.size, right.size, left.t)
323    }
324
325    pub fn left_basis(&self) -> Basis<F> {
326        Basis::make(self.left_size, self.type_)
327    }
328
329    pub fn right_basis(&self) -> Basis<F> {
330        Basis::make(self.right_size, self.type_)
331    }
332
333    fn output_basis(&self) -> Basis<F> {
334        Basis::make(
335            self.right_size + self.left_size - self.type_.size,
336            self.type_,
337        )
338    }
339    pub fn denom(&self) -> u32 {
340        let left_choice = (self.left_size - self.type_.size) as u32;
341        let right_choice = (self.right_size - self.type_.size) as u32;
342        binomial(left_choice, left_choice + right_choice)
343    }
344}
345
346impl<F: Flag> Savable<Vec<CsMat<u32>>, F> for SplitCount<F> {
347    fn filename(&self) -> String {
348        format!(
349            "split_{}_{}{}",
350            self.left_size,
351            self.right_size,
352            self.type_.to_string_suffix()
353        )
354    }
355    fn create(&self) -> Vec<CsMat<u32>> {
356        let left = self.left_basis().get();
357        let right = self.right_basis().get();
358        let target = self.output_basis().get();
359        count_split_tabulate(self.type_.size, &left, &right, &target)
360    }
361}
362
363//================ Subflag count
364
/// Operator counting the copies of `k`-vertex flags inside `n`-vertex flags
/// of the same type.
///
/// `.get()` gives a matrix M where M\[i,j\] is the number of copies of i in j
#[derive(Clone, Debug)]
pub struct SubflagCount<F: Flag> {
    /// Number of vertices of the inner (counted) flags.
    k: usize,
    /// Number of vertices of the outer flags.
    n: usize,
    /// Common type of the inner and outer flags.
    type_: Type<F>,
}
372
373impl<F: Flag> SubflagCount<F> {
374    pub fn make(k: usize, n: usize, type_: Type<F>) -> Self {
375        assert!(type_.size <= k);
376        assert!(k <= n);
377        Self { k, n, type_ }
378    }
379
380    pub fn from_to(inner: Basis<F>, outer: Basis<F>) -> Self {
381        assert_eq!(inner.t, outer.t);
382        Self::make(inner.size, outer.size, inner.t)
383    }
384
385    pub fn inner_basis(&self) -> Basis<F> {
386        Basis::make(self.k, self.type_)
387    }
388
389    fn outer_basis(&self) -> Basis<F> {
390        Basis::make(self.n, self.type_)
391    }
392
393    pub fn denom(&self) -> u32 {
394        let choices = (self.k - self.type_.size) as u32;
395        let total = (self.n - self.type_.size) as u32;
396        binomial(choices, total)
397    }
398}
399
400impl<F: Flag> Savable<CsMat<u32>, F> for SubflagCount<F> {
401    fn filename(&self) -> String {
402        format!(
403            "subflag_{}_to_{}{}",
404            self.n,
405            self.k,
406            self.type_.to_string_suffix()
407        )
408    }
409    fn create(&self) -> CsMat<u32> {
410        let inner = self.inner_basis().get();
411        let outer = self.outer_basis().get();
412        count_subflag_tabulate(self.type_.size, &inner, &outer)
413    }
414}
415
416// == unlabeling operators
/// Let F be the flag with index `flag` in the basis `basis`;
/// this represents the unlabeling operation that
/// sends the type `fully_typed(F)` to the flag `F`.
#[derive(Debug, Clone)]
pub struct Unlabeling<F: Flag> {
    /// Index of `F` in `basis`.
    pub flag: usize,
    /// Basis containing `F`; its type is the type kept after unlabeling.
    pub basis: Basis<F>,
}
425
426impl<F: Flag> Unlabeling<F> {
427    pub fn new(basis: Basis<F>, flag: usize) -> Self {
428        Self { flag, basis }
429    }
430    pub fn total(t: Type<F>) -> Self {
431        Self::new(Basis::new(t.size), t.id)
432    }
433
434    pub fn input_type(&self) -> Type<F>
435    where
436        F: Flag,
437    {
438        if self.basis.t == Type::empty() {
439            Type::new(self.basis.size, self.flag) // !!! Do we assume something ?
440        } else {
441            let basis = self.basis.get();
442            let unlab_basis = self.basis.without_type().get();
443            let flag = &basis[self.flag];
444            let unlab_id = unlab_basis.binary_search(&flag.canonical()).unwrap();
445            Type::new(self.basis.size, unlab_id)
446        }
447    }
448
449    pub fn output_type(&self) -> Type<F> {
450        self.basis.t
451    }
452    /// Return the eta function of Razborov corresponding to
453    /// the untyping operator.
454    pub fn eta(&self) -> Vec<usize>
455    where
456        F: Flag,
457    {
458        let flag = &self.basis.get()[self.flag];
459        let mut morphism = flag.morphism_to_canonical();
460        morphism.truncate(self.basis.t.size);
461        morphism
462    }
463}
464
/// The unlabeling operation `unlabeling` applied to flags with `size` vertices.
#[derive(Debug, Clone)]
pub struct Unlabel<F: Flag> {
    /// The unlabeling operation to apply.
    pub unlabeling: Unlabeling<F>,
    /// Number of vertices of the flags it is applied to.
    pub size: usize,
}
470
471impl<F: Flag> Savable<(Vec<usize>, Vec<u32>), F> for Unlabel<F> {
472    fn filename(&self) -> String {
473        format!(
474            "unlabel_{}_id_{}_basis_{}{}",
475            self.size,
476            self.unlabeling.flag,
477            self.unlabeling.basis.size,
478            self.unlabeling.basis.t.to_string_suffix()
479        )
480    }
481    fn create(&self) -> (Vec<usize>, Vec<u32>) {
482        let in_basis = Basis::<F>::make(self.size, self.unlabeling.input_type()).get();
483        let out_basis = Basis::<F>::make(self.size, self.unlabeling.output_type()).get();
484        let eta = self.unlabeling.eta();
485        (
486            unlabeling_tabulate(&eta, &in_basis, &out_basis),
487            unlabeling_count_tabulate(&eta, self.unlabeling.basis.size, &in_basis),
488        )
489    }
490}
491
492impl<F: Flag> Unlabel<F> {
493    pub fn denom(&self) -> u32 {
494        let new_type_size = self.unlabeling.basis.t.size;
495        let old_type_size = self.unlabeling.basis.size;
496        let choices = (old_type_size - new_type_size) as u32;
497        let free_vertices = (self.size - new_type_size) as u32;
498        product(free_vertices + 1 - choices, free_vertices)
499    }
500    pub fn output_basis(&self) -> Basis<F>
501    where
502        F: Flag,
503    {
504        Basis::new(self.size).with_type(self.unlabeling.output_type())
505    }
506    pub fn total(b: Basis<F>) -> Self {
507        Self {
508            unlabeling: Unlabeling::total(b.t),
509            size: b.size,
510        }
511    }
512}
513
/// Operator that multiplies two flags (through `split`) and then
/// unlabels the product (through `unlabeling`).
#[derive(Debug, Clone)]
pub struct MulAndUnlabel<F: Flag> {
    /// Multiplication step.
    pub split: SplitCount<F>,
    /// Unlabeling applied to the product.
    pub unlabeling: Unlabeling<F>,
}
519
520impl<F: Flag> MulAndUnlabel<F> {
521    pub fn invariant_classes(&self) -> InvariantClasses<F> {
522        assert_eq!(self.split.left_size, self.split.right_size);
523        let size = self.split.left_size;
524        InvariantClasses(Unlabel {
525            size,
526            unlabeling: self.unlabeling,
527        })
528    }
529    pub fn reduced(&self) -> ReducedByInvariant<F> {
530        ReducedByInvariant(*self)
531    }
532}
533
534impl<F: Flag> Display for MulAndUnlabel<F> {
535    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
536        write!(
537            f,
538            "Mul. and unlabel: {}x{}; {} -> {} (id {})",
539            self.split.left_size,
540            self.split.right_size,
541            self.split.type_,
542            self.unlabeling.basis.t,
543            self.unlabeling.flag
544        )
545    }
546}
547
548impl<F: Flag> MulAndUnlabel<F> {
549    fn unlabel(&self) -> Unlabel<F> {
550        Unlabel {
551            unlabeling: self.unlabeling,
552            size: self.split.output_basis().size,
553        }
554    }
555    pub fn output_basis(&self) -> Basis<F> {
556        Basis::new(self.split.output_basis().size).with_type(self.unlabeling.output_type())
557    }
558    pub fn denom(&self) -> u32 {
559        self.split.denom() * self.unlabel().denom()
560    }
561}
562
563impl<F: Flag> Savable<Vec<CsMat<i64>>, F> for MulAndUnlabel<F> {
564    fn filename(&self) -> String {
565        format!(
566            "{}_then_unlab_id_{}{}",
567            self.split.filename(),
568            self.unlabeling.flag,
569            self.unlabeling.basis.t.to_string_suffix()
570        )
571    }
572    fn create(&self) -> Vec<CsMat<i64>> {
573        let (unlab_f, unlab_c) = self.unlabel().get();
574        let mul = self.split.get();
575        assert!(!mul.is_empty());
576        let n = self.output_basis().get().len();
577        let pre = pre_image(n, &unlab_f);
578        let mut res = Vec::new();
579        for pre_i in &pre {
580            let mut res_i: CsMat<u32> = CsMat::zero(mul[0].shape());
581            for &j in pre_i {
582                let mut mul_j = mul[j].clone();
583                mul_j *= unlab_c[j]; //
584                res_i = &res_i + &mul_j; // can be optimized
585            }
586            res.push(res_i.map(|&v| v as i64))
587        }
588        debug_assert_eq!(res.len(), self.output_basis().get().len());
589        debug_assert_eq!(
590            res[0].shape(),
591            (
592                self.split.left_basis().get().len(),
593                self.split.right_basis().get().len()
594            )
595        );
596        res
597    }
598}
599
//================ Invariant classes
/// Classes of flags computed by `invariant_classes` for the wrapped
/// `Unlabel` operator. `.get()` returns a `Vec<usize>` (presumably one class
/// index per flag of the input basis — TODO confirm against `invariant_classes`).
#[derive(Clone, Debug)]
pub struct InvariantClasses<F: Flag>(Unlabel<F>);
603
604impl<F: Flag> Savable<Vec<usize>, F> for InvariantClasses<F> {
605    fn filename(&self) -> String {
606        format!(
607            "invariant_classes_{}_id_{}_basis_{}{}",
608            self.0.size,
609            self.0.unlabeling.flag,
610            self.0.unlabeling.basis.size,
611            self.0.unlabeling.basis.t.to_string_suffix()
612        )
613    }
614    fn create(&self) -> Vec<usize> {
615        let unlabeling = self.0.unlabeling;
616        let t = unlabeling.input_type();
617        let flags = Basis::<F>::new(self.0.size).with_type(t).get();
618        invariant_classes(&unlabeling.eta(), unlabeling.basis.size, &flags)
619    }
620}
621
/// A `MulAndUnlabel` operator projected on the invariant and anti-invariant
/// subspaces given by the matrices of `class_matrices`; `.get()` returns the
/// pair `(invariant parts, anti-invariant parts)`.
#[derive(Clone, Copy, Debug)]
pub struct ReducedByInvariant<F: Flag>(MulAndUnlabel<F>);
624
625impl<F: Flag> Savable<(Vec<CsMat<i64>>, Vec<CsMat<i64>>), F> for ReducedByInvariant<F> {
626    fn filename(&self) -> String {
627        format!("reduced_{}", self.0.filename())
628    }
629    fn create(&self) -> (Vec<CsMat<i64>>, Vec<CsMat<i64>>) {
630        let class = self.0.invariant_classes().get();
631        let (invariant_mat, antiinvariant_mat) = class_matrices(&class);
632        let mul_and_unlabel = self.0.get();
633        let mut res_inv = Vec::with_capacity(mul_and_unlabel.len());
634        let mut res_anti = Vec::with_capacity(mul_and_unlabel.len());
635        for m in mul_and_unlabel {
636            let invariant = &(&invariant_mat.transpose_view() * &m) * &invariant_mat;
637            let antiinvariant = if antiinvariant_mat.cols() == 0 {
638                CsMat::zero((0, 0)) // avoiding a small bug of sprs
639            } else {
640                &(&antiinvariant_mat.transpose_view() * &m) * &antiinvariant_mat
641            };
642            res_inv.push(invariant);
643            res_anti.push(antiinvariant);
644        }
645        (res_inv, res_anti)
646    }
647}
648
649// Workaround to give Basis and Unlabeling the Copy trait
650// (derive(Copy) does not derive the right bound when working
651// with PhantomData)
652impl<F: Flag> Copy for Type<F> {}
653impl<F: Flag> Copy for Unlabeling<F> {}
654impl<F: Flag> Copy for Basis<F> {}
655impl<F: Flag> Copy for SplitCount<F> {}
656impl<F: Flag> Copy for MulAndUnlabel<F> {}
657// =======================
658
659/// # Defining a quantum flags from a specified basis.
660impl<F: Flag> Basis<F> {
661    /// Sum of all flags of the basis.
662    /// This is an expression of the 1 of the flag algebra.
663    /// ```
664    /// use flag_algebra::*;
665    /// use flag_algebra::flags::Graph;
666    ///
667    /// let b = Basis::new(2);
668    /// let one = b.one();
669    /// let other: QFlag<i64, Graph> = b.random();
670    /// assert_eq!(&one * &other, other.expand(Basis::new(4)));
671    /// ```
672    pub fn one<N>(self) -> QFlag<N, F>
673    where
674        N: Num + Clone,
675    {
676        assert!(F::HEREDITARY || self.size == self.t.size);
677        let n = self.get().len();
678        QFlag {
679            basis: self,
680            data: Array::from_elem(n, N::one()),
681            scale: 1,
682            expr: Expr::FromIndicator(Rc::new(|_, _| true), self),
683        }
684    }
685    /// The zero vector in the specified basis.
686    /// ```
687    /// use flag_algebra::*;
688    /// use flag_algebra::flags::Graph;
689    ///
690    /// let basis = Basis::new(3);
691    /// let x: QFlag<i64, Graph> = basis.random();
692    /// assert_eq!(basis.zero() + &x, x);
693    /// ```
694    pub fn zero<N>(self) -> QFlag<N, F>
695    where
696        N: Num + Clone,
697    {
698        QFlag {
699            basis: self,
700            data: Array::zeros(self.get().len()),
701            scale: 1,
702            expr: Expr::Zero,
703        }
704    }
705    pub(crate) fn qflag_from_vec<N>(self, vec: Vec<N>) -> QFlag<N, F> {
706        assert_eq!(self.get().len(), vec.len());
707        QFlag {
708            basis: self,
709            data: Array::from(vec),
710            scale: 1,
711            expr: Expr::unknown(format!("from_vec({})", self.print_concise())),
712        }
713    }
714    /// Return the formal sum of the flags of the basis `self`
715    /// that satisfies some predicate `f`.
716    ///
717    /// The predicate `f` takes two arguments `g` and `sigma`, where `g` is a reference to
718    /// the flag and `sigma` is the size of the labeled part.
719    /// ```
720    /// use flag_algebra::*;
721    /// use flag_algebra::flags::Graph;
722    ///
723    /// // Sum of graphs of size 3 with an even number of edges
724    /// let b = Basis::<Graph>::new(3);
725    /// let sum = b.qflag_from_indicator(|g, _| g.edges().count() % 2 == 0 );
726    ///
727    /// let e3: QFlag<f64, Graph> = flag(&Graph::new(3, &[]));
728    /// let p3 = flag(&Graph::new(3, &[(0, 1), (1, 2)]));
729    /// assert_eq!(sum, e3 + &p3);
730    ///
731    /// /// Sum of the graphs of size 3 rooted on one vertex v
732    /// /// where v has degree at least 1
733    /// let t: Type<Graph> = Type::from_flag(&Graph::new(1, &[])); // Type for one vertex
734    /// let basis = Basis::new(3).with_type(t);
735    /// let sum: QFlag<f64, Graph> = basis.qflag_from_indicator(move |g, _| g.edge(0, 1) || g.edge(0, 2) );
736    /// ```
737    pub fn qflag_from_indicator<N, P>(self, predicate: P) -> QFlag<N, F>
738    where
739        P: Fn(&F, usize) -> bool + 'static,
740        N: One + Zero,
741    {
742        let indicator_rs: IndicatorFn<F> = Rc::new(move |a, b| predicate(a, b));
743        self.qflag_from_indicator_rc(indicator_rs)
744    }
745    pub(super) fn qflag_from_indicator_rc<N>(self, predicate: IndicatorFn<F>) -> QFlag<N, F>
746    where
747        N: One + Zero,
748    {
749        let vec: Vec<_> = self
750            .get()
751            .iter()
752            .map(|flag| {
753                if predicate(flag, self.t.size) {
754                    N::one()
755                } else {
756                    N::zero()
757                }
758            })
759            .collect();
760        QFlag {
761            basis: self,
762            data: Array::from(vec),
763            scale: 1,
764            expr: Expr::FromIndicator(predicate, self),
765        }
766    }
767    /// Return the formal sum of `f(g)*g` on the flags `g` of the basis `self`.
768    /// The second parameter of `f` is the size of the type of `g`.
769    /// ```
770    /// use flag_algebra::*;
771    /// use flag_algebra::flags::Graph;
772    ///
773    /// // Sum of graphs of size 3 weighted by their number of edges
774    /// let b = Basis::<Graph>::new(3);
775    /// let sum: QFlag<f64, Graph>  = b.qflag_from_coeff(|g, _| g.edges().count() as f64 );
776    /// ```
777
778    pub fn qflag_from_coeff<N, M, P>(self, f: P) -> QFlag<N, F>
779    where
780        P: Fn(&F, usize) -> M + 'static,
781        M: Into<N>,
782    {
783        let rc_f: CoefficientFn<F, N> = Rc::new(move |a, b| f(a, b).into());
784        self.qflag_from_coeff_rc(rc_f)
785    }
786    pub(crate) fn qflag_from_coeff_rc<N>(&self, f: CoefficientFn<F, N>) -> QFlag<N, F> {
787        let vec: Vec<_> = self.get().iter().map(|g| f(g, self.t.size)).collect();
788        QFlag {
789            basis: *self,
790            data: Array::from(vec),
791            scale: 1,
792            expr: Expr::FromFunction(f, *self),
793        }
794    }
795    pub fn random<N>(self) -> QFlag<N, F>
796    where
797        N: From<i16>,
798    {
799        let data: Vec<_> = (0..self.get().len())
800            .map(|_| {
801                let x: i16 = rand::random();
802                N::from(x)
803            })
804            .collect();
805        QFlag {
806            basis: self,
807            data: Array::from(data),
808            scale: 1,
809            expr: Expr::unknown(format!("random({})", self.print_concise())),
810        }
811    }
812    pub(crate) fn flag_from_id<N>(self, id: usize) -> QFlag<N, F>
813    where
814        N: Num + Clone,
815    {
816        self.flag_from_id_with_base_size(id, self.get().len())
817    }
818    pub(crate) fn flag_from_id_with_base_size<N>(self, id: usize, size: usize) -> QFlag<N, F>
819    where
820        N: Num + Clone,
821    {
822        let mut res = QFlag {
823            basis: self,
824            data: Array::zeros(size),
825            scale: 1,
826            expr: Expr::Flag(id, self),
827        };
828        res.data[id] = N::one();
829        res
830    }
831    /// Create a quantum flag containing exactly one flag.
832    pub fn flag<N>(self, f: &F) -> QFlag<N, F>
833    where
834        N: Num + Clone,
835    {
836        assert_eq!(self.size, f.size());
837        let flags = self.get();
838        let mut data = Array::zeros(flags.len());
839        let f1 = f.canonical_typed(self.t.size);
840        let id = flags.binary_search(&f1).expect("Flag not found in basis");
841        data[id] = N::one();
842        QFlag {
843            basis: self,
844            data,
845            scale: 1,
846            expr: Expr::Flag(id, self),
847        }
848    }
849    /// Returns the list of identifiers of all Square-and-unlabel operators
850    /// that can be used in Cauchy-Schwarz inequalities for a problem on the basis `self`.
851    pub fn all_cs(&self) -> Vec<MulAndUnlabel<F>> {
852        let mut res = Vec::new();
853        let n = self.size;
854        // m: size of a cs basis
855        for m in (n + self.t.size) / 2 + 1..=(2 * n - 1) / 2 {
856            let sigma = 2 * m - n;
857            let unlab_basis = Self::new(sigma).with_type(self.t);
858            for unlab_id in 0..unlab_basis.get().len() {
859                let unlabeling = Unlabeling::new(unlab_basis, unlab_id);
860                let input_basis = Self::new(m).with_type(unlabeling.input_type());
861                let split = SplitCount::from_input(&input_basis, &input_basis);
862                res.push(MulAndUnlabel { split, unlabeling })
863            }
864        }
865        res
866    }
867}
868
#[cfg(test)]
mod tests {
    use super::*;
    use crate::flag::SubClass;
    use crate::flags::*;

    #[test]
    fn basis() {
        assert_eq!(Basis::<Graph>::new(5).get().len(), 34);
        assert_eq!(Basis::<Graph>::make(3, Type::new(1, 0)).get().len(), 6);
        assert_eq!(Basis::<Graph>::make(4, Type::new(2, 1)).get().len(), 20);
        //
        assert_eq!(Basis::<OrientedGraph>::new(3).get().len(), 7);
        assert_eq!(Basis::<OrientedGraph>::new(5).get().len(), 582);
        assert_eq!(
            Basis::<OrientedGraph>::make(3, Type::new(1, 0)).get().len(),
            15
        );
        assert_eq!(
            Basis::<OrientedGraph>::make(4, Type::new(2, 0)).get().len(),
            126
        );
        //
        assert_eq!(Basis::<DirectedGraph>::new(2).get().len(), 3);
        assert_eq!(Basis::<DirectedGraph>::new(3).get().len(), 16);
        //
        assert_eq!(
            Basis::<SubClass<OrientedGraph, TriangleFree>>::new(3)
                .get()
                .len(),
            6
        );
        assert_eq!(
            Basis::<SubClass<OrientedGraph, TriangleFree>>::new(5)
                .get()
                .len(),
            317
        );
        assert_eq!(
            Basis::<SubClass<OrientedGraph, TriangleFree>>::make(3, Type::new(2, 1))
                .get()
                .len(),
            8
        );
    }
    #[test]
    fn splitcount() {
        assert_eq!(56, SplitCount::<Graph>::make(5, 7, Type::new(2, 1)).denom());
        assert_eq!(3, SplitCount::<Graph>::make(3, 4, Type::new(2, 1)).denom());
        // 2 + 8 free vertices: C(10, 2) = 45
        assert_eq!(
            45,
            SplitCount::<Graph>::make(4, 10, Type::new(2, 0)).denom()
        );
        let _ = SplitCount::<Graph>::make(3, 2, Type::empty()).get();
        let _ = SplitCount::<Graph>::make(2, 3, Type::empty()).get();
        let _ = SplitCount::<Graph>::make(2, 3, Type::new(1, 0)).get();
        //
        let _ = SplitCount::<OrientedGraph>::make(2, 3, Type::new(1, 0)).get();
    }
    #[test]
    fn subflagcount() {
        // Choose the 4 - 2 free vertices of the inner flag among the
        // 10 - 2 free vertices of the outer flag: C(8, 2) = 28
        assert_eq!(
            28,
            SubflagCount::<Graph>::make(4, 10, Type::new(2, 0)).denom()
        );
        let _ = SubflagCount::<Graph>::make(2, 3, Type::new(1, 0)).get();
        let _ = SubflagCount::<Graph>::make(3, 4, Type::new(2, 1)).get();
        let _ = SubflagCount::<Graph>::make(5, 5, Type::empty()).get();
        let _ = SubflagCount::<Graph>::make(3, 5, Type::new(1, 0)).get();
    }
    #[test]
    fn unlabel() {
        let t = Type::new(3, 1);
        let unlabeling = Unlabeling::<Graph>::total(t);
        let size = 5;
        assert_eq!((Unlabel { unlabeling, size }).denom(), 60);
        let _ = (Unlabel { unlabeling, size }).get();
        //
        let b = Basis::new(3).with_type(Type::new(2, 1));
        let unlabeling = Unlabeling::<Graph>::new(b, 0);
        let _ = (Unlabel { unlabeling, size }).get();
    }
    #[test]
    fn mulandunlabeling() {
        let t = Type::new(2, 1);
        let unlabeling = Unlabeling::<Graph>::total(t);
        let split = SplitCount::make(3, 2, t);
        let _mau = (MulAndUnlabel { split, unlabeling }).get();
    }
    #[test]
    fn type_iterator() {
        assert_eq!(Type::<Graph>::types_with_size(4).count(), 11);
    }
    //     #[test]
    //     fn unlabeling_eta() {
    //         let b = Basis::<Graph>::new(5).with_type(Type::new(3, 1));
    //         let unlabeling = Unlabeling::new(b, 1);
    //         let eta = unlabeling.eta();
    //         let g = &unlabeling.basis.get()[unlabeling.flag];
    //         let t = unlabeling.basis.t;
    //         assert_eq!(g.induce(&eta), Basis::new(t.size).get()[t.id])
    //     }
}