Skip to main content

Rustb/
output.rs

1//!这个模块是用来输出各种标准格式的, 包括
2//!
3//!wannier90_hr.dat 格式
4//!
5//!wannier90_centres.xyz 格式
6//!
7//!wannier90.win 格式
8//!
9//!整合的 wannier90 格式
10//!
11//!POSCAR 格式
12use crate::Model;
13use crate::kpath::*;
14use crate::model::{Dimension, find_R};
15use crate::error::{Result, TbError};
16use crate::kpoints::gen_kmesh;
17use crate::math::comm;
18use crate::solve_ham::solve;
19use ndarray::concatenate;
20use ndarray::linalg::kron;
21use ndarray::prelude::*;
22use ndarray::*;
23use ndarray_linalg::conjugate;
24use ndarray_linalg::*;
25use ndarray_linalg::{Eigh, UPLO};
26use num_complex::Complex;
27use rayon::prelude::*;
28use std::f64::consts::PI;
29use std::fs::File;
30use std::io::Write;
31use std::ops::AddAssign;
32use std::ops::MulAssign;
33
34pub trait OutPut {
35    /// 这个函数是用来将 tight-binding 模型输出到 wannier90_hr.dat 格式的
36    fn output_hr(&self, path: &str, seedname: &str);
37    fn output_POSCAR(&self, path: &str);
38    fn output_win(&self, path: &str, seedname: &str);
39    fn output_xyz(&self, path: &str, seedname: &str);
40    fn show_band(&self, path: &Array2<f64>, label: &Vec<&str>, nk: usize, name: &str)
41    -> Result<()>;
42}
43
44impl OutPut for Model {
45    fn output_hr(&self, path: &str, seedname: &str) {
46        let n_R = self.hamR.nrows(); //length of hamR
47        let mut hr_name = String::new();
48        hr_name.push_str(path);
49        hr_name.push_str(seedname);
50        hr_name.push_str("_hr.dat");
51        let mut file = File::create(hr_name).expect("Unable to BAND.dat");
52        writeln!(file, "{}", self.nsta());
53        writeln!(file, "{}", n_R);
54        let mut weight = String::new();
55        let lines = n_R.div_euclid(15);
56        let last_lines = n_R % 15;
57        if lines != 0 {
58            for i in 0..lines {
59                weight.push_str(
60                    "    1    1    1    1    1    1    1    1    1    1    1    1    1    1    1\n",
61                );
62            }
63        }
64        for i in 0..last_lines {
65            weight.push_str("    1");
66        }
67        writeln!(file, "{}", weight);
68        //接下来我们进行数据的写入
69        match self.dim_r {
70            Dimension::zero => {
71                let mut s = String::new();
72                let ham = self.ham.slice(s![0, .., ..]);
73                for orb_2 in 0..self.nsta() {
74                    for orb_1 in 0..self.nsta() {
75                        s.push_str(&format!(
76                            "0    0    0    {:15.8}    {:15.8}\n",
77                            ham[[orb_1, orb_2]].re,
78                            ham[[orb_1, orb_2]].im
79                        ));
80                    }
81                }
82            }
83            Dimension::one => {
84                let max_R1 = self.hamR.outer_iter().map(|x| x[[0]].abs()).max().unwrap();
85                let mut s = String::new();
86                for i in -max_R1..max_R1 {
87                    match (
88                        find_R(&self.hamR, &array![i as isize]),
89                        find_R(&self.hamR, &(-array![i as isize])),
90                    ) {
91                        (Some(r0), _) => {
92                            let ham = self.ham.slice(s![r0, .., ..]);
93                            for orb_2 in 0..self.nsta() {
94                                for orb_1 in 0..self.nsta() {
95                                    s.push_str(&format!(
96                                        "{:>3}    0    0    {:>3}    {:>3}    {:>15.8}    {:>15.8}\n",
97                                        i,
98                                        orb_1,
99                                        orb_2,
100                                        ham[[orb_1, orb_2]].re,
101                                        ham[[orb_1, orb_2]].im
102                                    ));
103                                }
104                            }
105                        }
106                        (None, Some(r0)) => {
107                            let ham = self.ham.slice(s![r0, .., ..]);
108                            for orb_2 in 0..self.nsta() {
109                                for orb_1 in 0..self.nsta() {
110                                    s.push_str(&format!(
111                                        "{:>3}    0    0    {:>3}    {:>3}    {:>15.8}    {:>15.8}\n",
112                                        i,
113                                        orb_1,
114                                        orb_2,
115                                        ham[[orb_1, orb_2]].re,
116                                        -ham[[orb_1, orb_2]].im
117                                    ));
118                                }
119                            }
120                        }
121                        (None, None) => {}
122                    }
123                }
124                writeln!(file, "{}", s);
125            }
126            Dimension::two => {
127                let max_values = self
128                    .hamR
129                    .fold_axis(Axis(0), isize::min_value(), |max, &value| {
130                        *max.max(&value.abs())
131                    });
132                let mut s = String::new();
133                for R1 in -max_values[[0]]..max_values[[0]] {
134                    for R2 in -max_values[[1]]..max_values[[1]] {
135                        let R0 = array![R1 as isize, R2 as isize];
136                        let R0_inv = -array![R1 as isize, R2 as isize];
137                        match (find_R(&self.hamR, &R0), find_R(&self.hamR, &R0_inv)) {
138                            (Some(r0), _) => {
139                                let ham = self.ham.slice(s![r0, .., ..]);
140                                for orb_2 in 0..self.nsta() {
141                                    for orb_1 in 0..self.nsta() {
142                                        s.push_str(&format!(
143                                            "{:>3}  {:>3}    0    {:>3}    {:>3}    {:>15.8}    {:>15.8}\n",
144                                            R1,
145                                            R2,
146                                            orb_1,
147                                            orb_2,
148                                            ham[[orb_1, orb_2]].re,
149                                            ham[[orb_1, orb_2]].im
150                                        ));
151                                    }
152                                }
153                            }
154                            (None, Some(r0)) => {
155                                let ham = self.ham.slice(s![r0, .., ..]);
156                                for orb_2 in 0..self.nsta() {
157                                    for orb_1 in 0..self.nsta() {
158                                        s.push_str(&format!(
159                                            "{:>3}  {:>3}    0    {:>3}    {:>3}    {:>15.8}    {:>15.8}\n",
160                                            R1,
161                                            R2,
162                                            orb_1,
163                                            orb_2,
164                                            ham[[orb_1, orb_2]].re,
165                                            -ham[[orb_1, orb_2]].im
166                                        ));
167                                    }
168                                }
169                            }
170                            (None, None) => {}
171                        }
172                    }
173                }
174                writeln!(file, "{}", s);
175            }
176            Dimension::three => {
177                let max_values = self
178                    .hamR
179                    .fold_axis(Axis(0), isize::min_value(), |max, &value| {
180                        *max.max(&value.abs())
181                    });
182                let mut s = String::new();
183                for R1 in -max_values[[0]]..max_values[[0]] {
184                    for R2 in -max_values[[1]]..max_values[[1]] {
185                        for R3 in -max_values[[2]]..max_values[[2]] {
186                            let R0 = array![R1 as isize, R2 as isize, R3 as isize];
187                            let R0_inv = -array![R1 as isize, R2 as isize, R3 as isize];
188                            match (find_R(&self.hamR, &R0), find_R(&self.hamR, &R0_inv)) {
189                                (Some(r0), _) => {
190                                    let ham = self.ham.slice(s![r0, .., ..]);
191                                    for orb_2 in 0..self.nsta() {
192                                        for orb_1 in 0..self.nsta() {
193                                            s.push_str(&format!(
194                                                "{:>3}  {:>3}  {:>3}    {:>3}    {:>3}    {:>15.8}    {:>15.8}\n",
195                                                R1,
196                                                R2,
197                                                R3,
198                                                orb_1,
199                                                orb_2,
200                                                ham[[orb_1, orb_2]].re,
201                                                ham[[orb_1, orb_2]].im
202                                            ));
203                                        }
204                                    }
205                                }
206                                (None, Some(r0)) => {
207                                    let ham = self.ham.slice(s![r0, .., ..]);
208                                    for orb_2 in 0..self.nsta() {
209                                        for orb_1 in 0..self.nsta() {
210                                            s.push_str(&format!(
211                                                "{:>3}  {:>3}  {:>3}    {:>3}    {:>3}    {:>15.8}    {:>15.8}\n",
212                                                R1,
213                                                R2,
214                                                R3,
215                                                orb_1,
216                                                orb_2,
217                                                ham[[orb_1, orb_2]].re,
218                                                -ham[[orb_1, orb_2]].im
219                                            ));
220                                        }
221                                    }
222                                }
223                                (None, None) => {}
224                            }
225                        }
226                    }
227                }
228                writeln!(file, "{}", s);
229            }
230        }
231    }
232
233    fn output_POSCAR(&self, path: &str) {
234        let mut name = String::new();
235        name.push_str(path);
236        name.push_str("POSCAR");
237        let mut file = File::create(&name).expect("Unable to BAND.dat");
238        writeln!(file, "Generate by Rustb");
239        writeln!(file, "1.0");
240        let s = match self.dim_r {
241            Dimension::three => {
242                let mut s = String::new();
243                s.push_str(&format!("    {:>15.8}    {:>15.8}    {:>15.8}\n    {:>15.8}    {:>15.8}    {:>15.8}\n    {:>15.8}    {:>15.8}    {:>15.8}",self.lat[[0,0]],self.lat[[0,1]],self.lat[[0,2]],self.lat[[1,0]],self.lat[[1,1]],self.lat[[1,2]],self.lat[[2,0]],self.lat[[2,1]],self.lat[[2,2]]));
244                s
245            }
246            Dimension::two => {
247                let mut s = String::new();
248                s.push_str(&format!("    {:>15.8}    {:>15.8}    {:>15.8}\n    {:>15.8}    {:>15.8}    {:>15.8}\n    {:>15.8}    {:>15.8}    {:>15.8}",self.lat[[0,0]],self.lat[[0,1]],0.0,self.lat[[1,0]],self.lat[[1,1]],0.0,0.0,0.0,10.0));
249                s
250            }
251            Dimension::one => {
252                let mut s = String::new();
253                s.push_str(&format!("    {:>15.8}    {:>15.8}    {:>15.8}\n    {:>15.8}    {:>15.8}    {:>15.8}\n    {:>15.8}    {:>15.8}    {:>15.8}",self.lat[[0,0]],0.0,0.0,0.0,10.0,0.0,0.0,0.0,10.0));
254                s
255            }
256            Dimension::zero => {
257                panic!(
258                    "Wrong! for POSCAR output, the dim_r of the model must be 1, 2 or 3, but yours zero"
259                );
260            }
261        };
262        writeln!(file, "{}", s);
263        //开始弄atom
264        let mut atom_type = vec![];
265        let mut atom_num = vec![];
266        let mut new_atom_position: Vec<Vec<Array1<f64>>> = Vec::new();
267        for i in 0..self.natom() {
268            let mut have_atom = false;
269            for j in 0..atom_type.len() {
270                if self.atoms[i].atom_type() == atom_type[j] {
271                    have_atom = true;
272                    atom_num[j] += 1;
273                    new_atom_position[j].push(self.atom_position().row(i).to_owned());
274                }
275            }
276            if have_atom == false {
277                atom_num.push(1);
278                atom_type.push(self.atoms[i].atom_type());
279                new_atom_position.push(vec![self.atom_position().row(i).to_owned()]);
280            }
281        }
282        let mut s = String::new();
283        for i in 0..atom_type.len() {
284            s.push_str(&format!("   {}", atom_type[i]));
285        }
286        writeln!(file, "{}", s);
287        let mut s = String::new();
288        for i in 0..atom_type.len() {
289            s.push_str(&format!("{:>4}", atom_num[i]));
290        }
291        writeln!(file, "{}", s);
292        writeln!(file, "Direct");
293        let mut s = String::new();
294        for i in 0..atom_type.len() {
295            for j in 0..new_atom_position[i].len() {
296                let s = match self.dim_r {
297                    Dimension::three => {
298                        let mut s = String::new();
299                        s.push_str(&format!(
300                            "{:>15.8}   {:>15.8}   {:>15.8}",
301                            new_atom_position[i][j][[0]],
302                            new_atom_position[i][j][[1]],
303                            new_atom_position[i][j][[2]]
304                        ));
305                        s
306                    }
307                    Dimension::two => {
308                        let mut s = String::new();
309                        s.push_str(&format!(
310                            "{:>15.8}   {:>15.8}   {:>15.8}",
311                            new_atom_position[i][j][[0]],
312                            new_atom_position[i][j][[1]],
313                            0.0
314                        ));
315                        s
316                    }
317                    Dimension::one => {
318                        let mut s = String::new();
319                        s.push_str(&format!(
320                            "{:>15.8}   {:>15.8}   {:>15.8}",
321                            new_atom_position[i][j][[0]],
322                            0.0,
323                            0.0
324                        ));
325                        s
326                    }
327                    Dimension::zero => {
328                        panic!(
329                            "Wrong! for POSCAR output, the dim_r of the model must be 1, 2 or 3, but yours zero"
330                        );
331                    }
332                };
333                writeln!(file, "{}", s);
334            }
335        }
336    }
337
338    fn output_win(&self, path: &str, seedname: &str) {
339        //!这个是用来输出 win 文件的. 这里projection 需要人为添加, 因为没有保存相关的projection 数据
340        let mut name = String::new();
341        name.push_str(path);
342        name.push_str(seedname);
343        name.push_str(".win");
344        let mut file = File::create(name).expect("Wrong, can't create seedname.win");
345        writeln!(file, "begin atoms_cart");
346        for at in self.atoms.iter() {
347            let atom_position = at.position();
348            match self.dim_r {
349                Dimension::three => {
350                    writeln!(
351                        file,
352                        "{}  {:>10.6}  {:>10.6}  {:>10.6}",
353                        at.atom_type(),
354                        atom_position[0],
355                        atom_position[1],
356                        atom_position[1]
357                    );
358                }
359                Dimension::two => {
360                    writeln!(
361                        file,
362                        "{}  {:>10.6}  {:>10.6}  {:>10.6}",
363                        at.atom_type(),
364                        atom_position[0],
365                        atom_position[1],
366                        0.0
367                    );
368                }
369                Dimension::one => {
370                    writeln!(
371                        file,
372                        "{}  {:>10.6}  {:>10.6}  {:>10.6}",
373                        at.atom_type(),
374                        atom_position[0],
375                        0.0,
376                        0.0
377                    );
378                }
379                Dimension::zero => panic!("Wrong, your model's dim_r is 0, not 1,2 or 3"),
380            }
381        }
382        writeln!(file, "end atoms_cart");
383        writeln!(file, "\n");
384        writeln!(file, "begin unit_cell_cart");
385        match self.dim_r {
386            Dimension::three => {
387                let mut s = String::new();
388                for i in 0..3 {
389                    for j in 0..3 {
390                        s.push_str(&format!("{:>10.6}  ", self.lat[[i, j]]));
391                    }
392                    writeln!(file, "{}", s);
393                }
394            }
395            Dimension::two => {
396                let mut s = String::new();
397                for i in 0..2 {
398                    for j in 0..2 {
399                        s.push_str(&format!("{:>10.6}  ", self.lat[[i, j]]));
400                    }
401                    s.push_str("   0.000000");
402                    writeln!(file, "{}", s);
403                }
404                writeln!(file, "   0.000000     0.000000     1.000000");
405            }
406            Dimension::one => {
407                let mut s = String::new();
408                s.push_str(&format!("{:>10.6}  ", self.lat[[0, 0]]));
409                s.push_str("   0.000000     0.000000");
410                writeln!(file, "{}", s);
411                writeln!(file, "   0.000000     0.000000     1.000000");
412                writeln!(file, "   0.000000     0.000000     1.000000");
413            }
414            Dimension::zero => {
415                panic!(
416                    "Wrong! Using output win file, the dim_r of model mut be 1, 2, or 3, but yours 0"
417                )
418            }
419        }
420        writeln!(file, "end unit_cell_cart");
421        writeln!(file, "\n");
422        //还差投影轨道
423        writeln!(file, "begin projections");
424        writeln!(file, "end projections");
425    }
426    fn output_xyz(&self, path: &str, seedname: &str) {
427        //!这个是用来输出 xyz 文件的. 这里projection 需要人为添加, 因为没有保存相关的projection 数据
428        let mut name = String::new();
429        name.push_str(path);
430        name.push_str(seedname);
431        name.push_str("_centres.xyz");
432        let mut file = File::create(name).expect("Wrong, can't create seedname.win");
433        let number = self.nsta() + self.natom();
434        let orb_real = self.orb.dot(&self.lat);
435        let atom_position_real = self.atom_position().dot(&self.lat);
436        writeln!(file, "{}", number);
437        writeln!(file, "Wannier centres, written by Rustb");
438        let mut s = match self.dim_r {
439            Dimension::three => {
440                let mut s = String::new();
441                for i in 0..self.norb() {
442                    s.push_str(&format!(
443                        "X{:>20.8}{:>17.8}{:>17.8}\n",
444                        orb_real[[i, 0]],
445                        orb_real[[i, 1]],
446                        orb_real[[i, 2]]
447                    ));
448                }
449                if self.spin {
450                    for i in 0..self.norb() {
451                        s.push_str(&format!(
452                            "X{:>20.8}{:>17.8}{:>17.8}\n",
453                            orb_real[[i, 0]],
454                            orb_real[[i, 1]],
455                            orb_real[[i, 2]]
456                        ));
457                    }
458                }
459                for i in 0..self.natom() - 1 {
460                    s.push_str(&format!(
461                        "{}{:>19.8}{:>17.8}{:>17.8}\n",
462                        self.atoms[i].atom_type(),
463                        atom_position_real[[i, 0]],
464                        atom_position_real[[i, 1]],
465                        atom_position_real[[i, 2]]
466                    ));
467                }
468                let i = self.natom() - 1;
469                s.push_str(&format!(
470                    "{}{:>19.8}{:>17.8}{:>17.8}",
471                    self.atoms[i].atom_type(),
472                    atom_position_real[[i, 0]],
473                    atom_position_real[[i, 1]],
474                    atom_position_real[[i, 2]]
475                ));
476                s
477            }
478            Dimension::two => {
479                let mut s = String::new();
480                for i in 0..self.norb() {
481                    s.push_str(&format!(
482                        "X{:>20.8}{:>17.8}       0.00000000\n",
483                        orb_real[[i, 0]],
484                        orb_real[[i, 1]]
485                    ));
486                }
487                if self.spin {
488                    for i in 0..self.norb() {
489                        s.push_str(&format!(
490                            "X{:>20.8}{:>17.8}       0.00000000\n",
491                            orb_real[[i, 0]],
492                            orb_real[[i, 1]]
493                        ));
494                    }
495                }
496                for i in 0..self.natom() - 1 {
497                    s.push_str(&format!(
498                        "{}{:>19.8}{:>17.8}       0.00000000\n",
499                        self.atoms[i].atom_type(),
500                        atom_position_real[[i, 0]],
501                        atom_position_real[[i, 1]]
502                    ));
503                }
504                let i = self.natom() - 1;
505                s.push_str(&format!(
506                    "{}{:>19.8}{:>17.8}       0.00000000",
507                    self.atoms[i].atom_type(),
508                    atom_position_real[[i, 0]],
509                    atom_position_real[[i, 1]]
510                ));
511                s
512            }
513            Dimension::one => {
514                let mut s = String::new();
515                for i in 0..self.norb() {
516                    s.push_str(&format!(
517                        "X{:>20.8}       0.00000000       0.00000000\n",
518                        orb_real[[i, 0]]
519                    ));
520                }
521                if self.spin {
522                    for i in 0..self.norb() {
523                        s.push_str(&format!(
524                            "X{:>20.8}       0.00000000       0.00000000\n",
525                            orb_real[[i, 0]]
526                        ));
527                    }
528                }
529                for i in 0..self.natom() - 1 {
530                    s.push_str(&format!(
531                        "{}{:>19.8}       0.00000000       0.00000000\n",
532                        self.atoms[i].atom_type(),
533                        atom_position_real[[i, 0]]
534                    ));
535                }
536                let i = self.natom() - 1;
537                s.push_str(&format!(
538                    "{}{:>19.8}       0.00000000       0.00000000",
539                    self.atoms[i].atom_type(),
540                    atom_position_real[[i, 0]]
541                ));
542                s
543            }
544            Dimension::zero => {
545                panic!("Wrong!, the dim_r must be 1,2 or 3, but yours 0");
546            }
547        };
548        writeln!(file, "{}", s);
549    }
550
551    ///这个函数是用来快速画能带图的, 用python画图, 因为Rust画图不太方便.
552    #[allow(non_snake_case)]
553    fn show_band(
554        &self,
555        path: &Array2<f64>,
556        label: &Vec<&str>,
557        nk: usize,
558        name: &str,
559    ) -> Result<()> {
560        use gnuplot::AutoOption::*;
561        use gnuplot::AxesCommon;
562        use gnuplot::Tick::*;
563        use gnuplot::{Caption, Color, Figure, Font, LineStyle, Solid};
564        use std::fs::create_dir_all;
565        use std::path::Path;
566        if path.len_of(Axis(0)) != label.len() {
567            panic!(
568                "Error, the path's length {} and label's length {} must be equal!",
569                path.len_of(Axis(0)),
570                label.len()
571            )
572        }
573        let (k_vec, k_dist, k_node) = self.k_path(&path, nk)?;
574        let eval = self.solve_band_all_parallel(&k_vec);
575        create_dir_all(name).map_err(|e| TbError::DirectoryCreation {
576            path: name.to_string(),
577            message: e.to_string(),
578        })?;
579        let mut name0 = String::new();
580        name0.push_str("./");
581        name0.push_str(&name);
582        let name = name0;
583        let mut band_name = name.clone();
584        band_name.push_str("/BAND.dat");
585        let band_name = Path::new(&band_name);
586        let mut file = File::create(band_name).expect("Unable to BAND.dat");
587        for i in 0..nk {
588            let mut s = String::new();
589            let aa = format!("{:.6}", k_dist[[i]]);
590            s.push_str(&aa);
591            for j in 0..self.nsta() {
592                if eval[[i, j]] >= 0.0 {
593                    s.push_str("     ");
594                } else {
595                    s.push_str("    ");
596                }
597                let aa = format!("{:.6}", eval[[i, j]]);
598                s.push_str(&aa);
599            }
600            writeln!(file, "{}", s)?;
601        }
602        let mut k_name = name.clone();
603        k_name.push_str("/KLABELS");
604        let k_name = Path::new(&k_name);
605        let mut file = File::create(k_name).expect("Unable to create KLBAELS"); //写下高对称点的位置
606        for i in 0..path.len_of(Axis(0)) {
607            let mut s = String::new();
608            let aa = format!("{:.6}", k_node[[i]]);
609            s.push_str(&aa);
610            s.push_str("      ");
611            s.push_str(&label[i]);
612            writeln!(file, "{}", s)?;
613        }
614        let mut py_name = name.clone();
615        py_name.push_str("/print.py");
616        let py_name = Path::new(&py_name);
617        let mut file = File::create(py_name).expect("Unable to create print.py");
618        writeln!(
619            file,
620            "import numpy as np\nimport matplotlib.pyplot as plt\ndata=np.loadtxt('BAND.dat')\nk_nodes=[]\nlabel=[]\nf=open('KLABELS')\nfor i in f.readlines():\n    k_nodes.append(float(i.split()[0]))\n    label.append(i.split()[1])\nfig,ax=plt.subplots()\nax.plot(data[:,0],data[:,1:],c='b')\nfor x in k_nodes:\n    ax.axvline(x,c='k')\nax.set_xticks(k_nodes)\nax.set_xticklabels(label)\nax.set_xlim([0,k_nodes[-1]])\nfig.savefig('band.pdf')"
621        );
622        //开始绘制pdf图片
623        let mut fg = Figure::new();
624        let x: Vec<f64> = k_dist.to_vec();
625        let axes = fg.axes2d();
626        for i in 0..self.nsta() {
627            let y: Vec<f64> = eval.slice(s![.., i]).to_owned().to_vec();
628            axes.lines(&x, &y, &[Color("black"), LineStyle(Solid)]);
629        }
630        let axes = axes.set_x_range(Fix(0.0), Fix(k_node[[k_node.len() - 1]]));
631        let label = label.clone();
632        let mut show_ticks = Vec::new();
633        for i in 0..k_node.len() {
634            let A = k_node[[i]];
635            let B = label[i];
636            show_ticks.push(Major(A, Fix(B)));
637        }
638        axes.set_x_ticks_custom(
639            show_ticks.into_iter(),
640            &[],
641            &[Font("Times New Roman", 24.0)],
642        );
643
644        let k_node = k_node.to_vec();
645        let mut pdf_name = name.clone();
646        pdf_name.push_str("/plot.pdf");
647        fg.set_terminal("pdfcairo", &pdf_name);
648        fg.show();
649        Ok(())
650    }
651}
652
653pub fn draw_heatmap<A: Data<Elem = f64>>(data: &ArrayBase<A, Ix2>, name: &str) {
654    //!这个函数是用来画热图的, 给定一个二维矩阵, 会输出一个像素图片
655    use gnuplot::{AutoOption::Fix, AxesCommon, Figure, HOT, RAINBOW};
656    let mut fg = Figure::new();
657    let (height, width): (usize, usize) = (data.shape()[0], data.shape()[1]);
658    let mut heatmap_data = vec![];
659
660    for j in 0..width {
661        for i in 0..height {
662            heatmap_data.push(data[(i, j)]);
663        }
664    }
665    let axes = fg.axes2d();
666    axes.set_title("Heatmap", &[]);
667    axes.set_cb_label("Values", &[]);
668    axes.set_palette(RAINBOW);
669    axes.image(heatmap_data.iter(), width, height, None, &[]);
670    let size = data.shape();
671    let axes = axes.set_x_range(Fix(0.0), Fix((size[0] - 1) as f64));
672    let axes = axes.set_y_range(Fix(0.0), Fix((size[1] - 1) as f64));
673    let axes = axes.set_aspect_ratio(Fix(1.0));
674    fg.set_terminal("pdfcairo", name);
675    fg.show().expect("Unable to draw heatmap");
676}