//! qubic_engine 0.1.7
//!
//! qubic engine — documentation.
use sqlite::{open, Connection};

use crate::ai::u2vec;
use crate::board::Board;
use crate::train;
use crate::{
    ml::{create_batch, Tensor},
    train::Transition,
    utills::rand::get_random_usize,
};

/// SQLite-backed store of recorded positions (table `board_record`) that can be
/// sampled in random batches for training.
pub struct BoardDB {
    /// Number of batches still to be produced when iterating over this value.
    pub batch_num: usize,
    /// Open connection to the backing SQLite database.
    conn: Connection,
    /// Number of rows drawn per batch.
    batch_size: usize,
    /// Blending factor stored by `set_lambda`.
    lambda: f32,
}

impl BoardDB {
    /// Opens the SQLite database at `s` and makes sure the `board_record`
    /// table exists.
    ///
    /// `batch_size` is the number of rows drawn per batch. Passing `0` means
    /// "one batch containing every row", in which case the current row count is
    /// used as the batch size. `batch_num` is then initialised via
    /// [`BoardDB::set_batch_num`].
    ///
    /// # Panics
    /// Panics if the database cannot be opened or the table cannot be created.
    pub fn new(s: &str, batch_size: usize) -> Self {
        let conn = open(s).expect("failed to open board database");

        let query = "
            create table if not exists board_record (
                att integer,
                def integer,
                flag integer,
                val real
            )
        ";

        conn.execute(query)
            .expect("failed to create board_record table");
        let mut db = BoardDB {
            conn,
            batch_size,
            batch_num: 0,
            lambda: 0.0,
        };

        if batch_size == 0 {
            // 0 means "use the whole table as a single batch".
            db.batch_size = db.get_count();
        }

        db.set_batch_num();
        db
    }

    /// Stores the blending factor `lambda`.
    ///
    /// NOTE(review): no visible code reads `self.lambda`; confirm whether batch
    /// targets are meant to use it.
    pub fn set_lambda(&mut self, lambda: f32) {
        self.lambda = lambda;
    }

    /// Recomputes `batch_num` as the number of full batches in the table
    /// (`row_count / batch_size`, rounded down).
    ///
    /// When `batch_size` is 0 (e.g. `new` was called with 0 on an empty table),
    /// `batch_num` becomes 0 instead of panicking on a division by zero.
    pub fn set_batch_num(&mut self) {
        self.batch_num = self
            .get_count()
            .checked_div(self.batch_size)
            .unwrap_or(0);
    }

    /// Returns the current value of `batch_num`.
    pub fn get_batch_num(&self) -> usize {
        self.batch_num
    }

    /// Opens an SQLite transaction on this connection (best effort).
    ///
    /// The method name's spelling is kept for API compatibility; the SQL sent
    /// is `begin`. Errors (e.g. a transaction is already active) are ignored.
    pub fn begine(&self) {
        let _ = self.conn.execute("begin");
    }

    /// Commits the current transaction (`END` is SQLite's alias for `COMMIT`).
    ///
    /// Errors (e.g. no transaction is active) are ignored.
    pub fn end(&self) {
        let _ = self.conn.execute("end");
    }

    /// Inserts one record into `board_record`.
    ///
    /// `att` and `def` are cast to `i64` before storing; readers cast them back
    /// to `u64`, so the bit pattern round-trips.
    ///
    /// # Panics
    /// Panics if the insert fails.
    /// NOTE(review): `val` is interpolated with `Display`, so a NaN or infinite
    /// value produces invalid SQL and panics here.
    pub fn add(&self, att: u64, def: u64, flag: i32, val: f32) {
        let query = format!(
            "
                insert into board_record(att, def, flag, val)
                values({}, {}, {}, {})",
            att as i64, def as i64, flag, val
        );

        self.conn
            .execute(query)
            .expect("failed to insert board record");
    }

    /// Copies every record of `other` into this database, inside one
    /// transaction opened with [`BoardDB::begine`] and closed with
    /// [`BoardDB::end`].
    ///
    /// Prints `merge BoardDB` and then the number of rows copied so far every
    /// 1000 rows.
    ///
    /// # Panics
    /// Panics if reading from `other` or inserting into `self` fails.
    pub fn concat(&self, other: BoardDB) {
        let query = "
                select att, def, flag, val from board_record";

        let mut count = 0;

        self.begine();
        println!("merge BoardDB");
        other
            .conn
            .iterate(query, |pairs| {
                if count % 1000 == 0 {
                    println!("{count}");
                }
                count += 1;
                let (att, def, flag, val) = Self::parse_row(pairs);
                self.add(att, def, flag, val);
                true
            })
            .expect("failed to read records from the source BoardDB");
        self.end();
    }

    /// Returns up to `size` randomly ordered raw records as
    /// `(att, def, flag, val)` tuples.
    ///
    /// # Panics
    /// Panics if the query fails or a row cannot be decoded.
    pub fn get(&self, size: usize) -> Vec<(u64, u64, i32, f32)> {
        let query = format!(
            "
                select att, def, flag, val from board_record order by random() limit {}",
            size
        );

        let mut rows = Vec::new();
        self.conn
            .iterate(query, |pairs| {
                rows.push(Self::parse_row(pairs));
                true
            })
            .expect("failed to read board_record");
        rows
    }

    /// Returns the number of rows in `board_record`.
    ///
    /// # Panics
    /// Panics if the query fails.
    pub fn get_count(&self) -> usize {
        let query = "SELECT COUNT(*) FROM board_record";

        let mut count = 0;
        self.conn
            .iterate(query, |pairs| {
                // COUNT(*) yields a single row with a single non-NULL column.
                for &(_, value) in pairs {
                    count = value
                        .expect("COUNT(*) returned NULL")
                        .parse()
                        .expect("COUNT(*) returned a non-integer value");
                }
                true
            })
            .expect("failed to count board_record rows");
        count
    }

    /// Samples up to `batch_size` random records as [`Transition`]s.
    ///
    /// # Panics
    /// Panics if the query fails or a row cannot be decoded.
    pub fn get_batch(&self) -> Vec<Transition> {
        let query = format!(
            "
                select att, def, flag, val from board_record order by random() limit {}",
            self.batch_size
        );
        self.query_transitions(&query)
    }

    /// Returns every record in the table as [`Transition`]s.
    ///
    /// # Panics
    /// Panics if the query fails or a row cannot be decoded.
    pub fn get_all(&self) -> Vec<Transition> {
        let query = "
                select att, def, flag, val from board_record";
        self.query_transitions(query)
    }

    /// Runs a `select att, def, flag, val ...` query and decodes every row.
    fn query_transitions(&self, query: &str) -> Vec<Transition> {
        let mut ts = Vec::new();
        self.conn
            .iterate(query, |pairs| {
                ts.push(Self::row_to_transition(pairs));
                true
            })
            .expect("failed to read board_record");
        ts
    }

    /// Decodes one row into a [`Transition`], packing `att` into the low 64
    /// bits and `def` into the high 64 bits of `board`.
    fn row_to_transition(pairs: &[(&str, Option<&str>)]) -> Transition {
        let (att, def, flag, val) = Self::parse_row(pairs);
        Transition {
            board: (att as u128) | ((def as u128) << 64),
            result: flag,
            t_val: val,
        }
    }

    /// Decodes the first four columns `(att, def, flag, val)` of a row.
    ///
    /// # Panics
    /// Panics if the row has fewer than four columns, a column is NULL, or a
    /// value does not parse.
    fn parse_row(pairs: &[(&str, Option<&str>)]) -> (u64, u64, i32, f32) {
        /// Parses one `(column_name, value)` pair, panicking with the column name.
        fn field<T: std::str::FromStr>((name, value): (&str, Option<&str>)) -> T {
            let text = value.unwrap_or_else(|| panic!("column `{}` is NULL", name));
            text.parse()
                .unwrap_or_else(|_| panic!("column `{}` has unparsable value {:?}", name, text))
        }

        let row = pairs
            .get(0..4)
            .expect("board_record row must have at least 4 columns");
        let att: i64 = field(row[0]);
        let def: i64 = field(row[1]);
        let flag: i32 = field(row[2]);
        let val: f32 = field(row[3]);
        // Undo the `as i64` performed by `add`; the bit pattern is preserved.
        (att as u64, def as u64, flag, val)
    }
}

/// Applies one of eight combinations of `Board::hflip` and `Board::rot` to `b`.
///
/// Only `id % 8` matters, so any `usize` (e.g. a raw random number) may be
/// passed. For `id % 8` in `0..4`, `b` is first passed through `Board::hflip`;
/// for `4..8` it is not. The result is then passed through `Board::rot`
/// `id % 4` times.
pub fn random_rot(b: u128, id: usize) -> u128 {
    let id = id % 8;
    let start = if id < 4 { Board::hflip(b) } else { b };
    (0..id % 4).fold(start, |acc, _| Board::rot(acc))
}

impl Iterator for BoardDB {
    type Item = (Tensor, Tensor);

    /// Yields the next `(boards, targets)` batch, or `None` once `batch_num`
    /// has reached zero (each call decrements it).
    ///
    /// For every transition returned by `get_batch`:
    /// - the board is passed through [`random_rot`] with a random id, encoded
    ///   with `u2vec`, and wrapped in a tensor with shape argument `[128, 1]`;
    /// - the target is `outcome * train::LAMBDA + (1 - train::LAMBDA) * t_val`,
    ///   where `outcome` is 1.0 for `result == 1`, 0.0 for `result == -1`, and
    ///   0.5 otherwise.
    ///
    /// The per-item tensors are combined with `create_batch`.
    ///
    /// NOTE(review): this uses the constant `train::LAMBDA`, not the `lambda`
    /// field set by `set_lambda` — confirm which one is intended.
    fn next(&mut self) -> Option<Self::Item> {
        if self.batch_num == 0 {
            return None;
        }

        let (boards, targets): (Vec<Tensor>, Vec<Tensor>) = self
            .get_batch()
            .iter()
            .map(|t| {
                let outcome = match t.result {
                    1 => 1.0,
                    -1 => 0.0,
                    _ => 0.5,
                };
                let rotated = random_rot(t.board, get_random_usize());
                let board_tensor = Tensor::new(u2vec(rotated), vec![128, 1]);
                let target = outcome * train::LAMBDA + (1.0 - train::LAMBDA) * t.t_val;
                (board_tensor, Tensor::new(vec![target], vec![1, 1]))
            })
            .unzip();

        let boards = create_batch(boards);
        let targets = create_batch(targets);

        self.batch_num -= 1;

        Some((boards, targets))
    }
}