#![allow(dead_code)]
#![allow(unused_variables)]
#![allow(unused_imports)]
#![allow(non_fmt_panics)]
#![allow(unused_mut)]
#![allow(unused_assignments)]
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
#![allow(rustdoc::missing_crate_level_docs)]
#![allow(unsafe_code)]
#![allow(clippy::undocumented_unsafe_blocks)]
#![allow(unused_must_use)]
#![allow(non_snake_case)]
#![allow(clippy::upper_case_acronyms)]
use rand::{
distr::{
weighted::WeightedIndex, Distribution
},
rng, Rng,
};
use rand::prelude::{
IndexedRandom, IteratorRandom, SliceRandom
};
use serde::{Deserialize, Serialize};
use smallvec::ToSmallVec;
use bitvec::{order::Msb0, view::BitView};
use serde_yml::{to_writer, from_reader};
use flate2::Compression;
use flate2::write::GzEncoder;
use flate2::read::GzDecoder;
use std::fmt;
use std::fs::File;
use std::io::{self, Write, Read};
use std::path::Path;
use std::collections::HashMap;
use crate::{
BowResult, BoW, Desc, DirectIdx,
IdPath, BowErr,
};
use crate::vocabulary::{
Vocabulary, NodeId, Block,
Children,
};
/// Serde mirror of one entry in the `nodes` list of a DBoW3 YAML vocabulary.
///
/// Field names keep DBoW3's camelCase spelling on purpose: serde uses them verbatim
/// as the YAML keys.
///
/// NOTE(review): the `save_dbow3`/`load_dbow3` code in this file writes and parses the
/// YAML by hand and does not go through these types — confirm whether they are still needed.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DBoW3Node {
    /// Id of this node.
    nodeId: u32,
    /// Id of this node's parent node.
    parentId: u32,
    /// Weight stored for the node.
    weight: f32,
    /// Descriptor in textual form (presumably space-separated byte values — TODO confirm).
    descriptor: String,
}

/// Serde mirror of one entry in the `words` list: maps a word id to a node id.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DBoW3Word {
    /// Id of the word.
    wordId: u32,
    /// Id of the node the word refers to.
    nodeId: u32,
}

/// Serde mirror of a complete DBoW3 vocabulary document.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DBoW3Vocabulary {
    /// Branching factor.
    k: u32,
    /// Number of levels (serialized under the key `L`).
    L: u32,
    /// Scoring type code.
    scoringType: u32,
    /// Weighting type code.
    weightingType: u32,
    /// All tree nodes.
    nodes: Vec<DBoW3Node>,
    /// Word-to-node mapping.
    words: Vec<DBoW3Word>,
}
impl Vocabulary {
pub fn save_dbow3<P: AsRef<std::path::Path>>(&self, file: P) -> BowResult<()> {
let file_path = file.as_ref();
let mut file = File::create(file_path)?;
if file_path.extension().and_then(|s| s.to_str()) == Some("gz") {
let mut gz = GzEncoder::new(file, Compression::default());
Self::generate_dbow3_yaml(self, &mut gz)?;
gz.finish()?;
} else {
Self::generate_dbow3_yaml(self, &mut file)?;
}
Ok(())
}
pub fn load_dbow3<P: AsRef<std::path::Path>>(file: P) -> BowResult<Self> {
let file_path = file.as_ref();
if file_path.extension().and_then(|s| s.to_str()) == Some("gz") {
let file = File::open(file_path)?;
let gz = GzDecoder::new(file);
Self::parse_dbow3_yaml(gz)
} else {
let file = File::open(file_path)?;
Self::parse_dbow3_yaml(file)
}
}
fn generate_dbow3_yaml(vocab: &Vocabulary, writer: &mut impl Write) -> BowResult<()> {
writeln!(writer, "%YAML 1.0")?;
writeln!(writer, "vocabulary:")?;
writeln!(writer, " k: {}", vocab.k)?;
writeln!(writer, " L: {}", vocab.levels)?;
writeln!(writer, " scoringType: 0")?;
writeln!(writer, " weightingType: 0")?;
writeln!(writer, " nodes:")?;
for block in &vocab.blocks {
for (i, child) in block.children.ids.iter().enumerate() {
if let NodeId::Block(id) = child {
writeln!(writer, " - {{ nodeId:{}, parentId:{}, weight:{}, descriptor:\"{}\" }}",
id, block.id.get_bid(), block.children.weights[i],
block.children.features[i].iter().map(|b| b.to_string()).collect::<Vec<_>>().join(" "))?;
}
}
}
writeln!(writer, " words:")?;
for (word_id, node_id) in vocab.blocks.iter().flat_map(|b| b.children.ids.iter().enumerate())
.filter(|(_, n)| matches!(n, NodeId::Leaf(_)))
.map(|(i, n)| (i, n.get_bid())) {
writeln!(writer, " - {{ wordId:{}, nodeId:{} }}", word_id, node_id)?;
}
Ok(())
}
fn parse_dbow3_yaml(mut reader: impl Read) -> BowResult<Vocabulary> {
let mut contents = String::new();
reader.read_to_string(&mut contents)?;
let mut vocab = Vocabulary::empty(10, 5);
let lines: Vec<&str> = contents.lines().collect();
let mut current_section = None;
let mut nodes = Vec::new();
let mut words = Vec::new();
let mut i = 0;
println!("开始解析YAML文件,总行数: {}", lines.len());
while i < lines.len() {
let line = lines[i].trim();
if line.is_empty() {
i += 1;
continue;
}
if line.starts_with("nodes:") {
current_section = Some("nodes");
println!("进入节点解析部分");
} else if line.starts_with("words:") {
current_section = Some("words");
println!("进入词解析部分");
} else if line.starts_with("- {") {
let content = line.split('{').nth(1).and_then(|s| s.split('}').next())
.unwrap_or("");
println!("解析内容: {{{}}}", content);
match current_section {
Some("nodes") => {
println!("解析节点信息,当前行: {}", i);
let mut node_content = String::new();
node_content.push_str(line.split('{').nth(1).unwrap_or(""));
i += 1;
while i < lines.len() && !lines[i].trim().ends_with('}') {
node_content.push_str(lines[i].trim());
i += 1;
}
if i < lines.len() {
node_content.push_str(lines[i].trim().split('}').next().unwrap_or(""));
}
let mut node_data = HashMap::new();
for pair in node_content.split(',') {
let kv: Vec<&str> = pair.split(':').collect();
if kv.len() == 2 {
node_data.insert(kv[0].trim(), kv[1].trim());
}
}
let node_id: usize = node_data.get("nodeId")
.ok_or(BowErr::ParseError("Missing nodeId".to_string()))?
.parse()?;
let parent_id: usize = node_data.get("parentId")
.ok_or(BowErr::ParseError("Missing parentId".to_string()))?
.parse()?;
let weight = node_data.get("weight")
.map(|w| w.parse().unwrap_or(1.0))
.unwrap_or(1.0);
let descriptor: Vec<u8> = node_data.get("descriptor")
.ok_or(BowErr::ParseError("Missing descriptor".to_string()))?
.trim_matches('"')
.split_whitespace()
.take(32) .map(|s| s.parse().map_err(|_| BowErr::ParseError("Invalid descriptor format".to_string())))
.collect::<Result<Vec<_>, _>>()?;
let mut descriptor = descriptor;
descriptor.resize(32, 0);
println!("成功解析节点: nodeId={}, parentId={}, weight={}, descriptor长度={}", node_id, parent_id, weight, descriptor.len());
nodes.push((node_id, parent_id, weight, descriptor));
}
Some("words") => {
println!("解析词信息,当前行: {}", i);
let word_content = line.split('{').nth(1).and_then(|s| s.split('}').next())
.ok_or(BowErr::ParseError("Invalid word format".to_string()))?;
let mut word_data = HashMap::new();
for pair in word_content.split(',') {
let kv: Vec<&str> = pair.split(':').collect();
if kv.len() == 2 {
word_data.insert(kv[0].trim(), kv[1].trim());
}
}
let word_id: usize = word_data.get("wordId")
.ok_or(BowErr::ParseError("Missing wordId".to_string()))?
.parse()?;
let node_id: usize = word_data.get("nodeId")
.ok_or(BowErr::ParseError("Missing nodeId".to_string()))?
.parse()?;
println!("成功解析词: wordId={}, nodeId={}", word_id, node_id);
words.push((word_id, node_id));
}
_ => {}
}
} else if line.starts_with("k:") {
vocab.k = line.split(':').nth(1).unwrap().trim().parse().unwrap();
println!("解析k值: {}", vocab.k);
} else if line.starts_with("L:") {
vocab.levels = line.split(':').nth(1).unwrap().trim().parse().unwrap();
println!("解析L值: {}", vocab.levels);
}
i += 1;
}
println!("开始构建词汇表结构,节点数: {}, 词数: {}", nodes.len(), words.len());
for (node_id, parent_id, weight, descriptor) in &nodes {
let block = Block {
id: NodeId::Block(*parent_id),
children: Children {
features: vec![descriptor.as_slice().try_into().expect("Descriptor length mismatch")],
weights: vec![*weight],
cluster_size: vec![1],
ids: vec![NodeId::Block(*node_id)],
},
};
vocab.blocks.push(block);
}
vocab.num_leaves = words.len();
vocab.num_blocks = nodes.len();
println!("词汇表解析完成,总词数: {}, 总节点数: {}", vocab.num_leaves, vocab.num_blocks);
Ok(vocab)
}
}
#[cfg(test)]
mod tests2 {
    use super::*;

    /// Round-trips a small vocabulary through `save_dbow3` / `load_dbow3`, as plain YAML
    /// and gzip-compressed. Checks that `k`, `L` and the node/word counts survive.
    #[test]
    fn test_save_and_load_dbow3() {
        use crate::keypoint::load_img_get_kps;
        use std::path::PathBuf;

        let path = PathBuf::from("./assets/290.png");
        let n_keypoints = 5;
        let descs = load_img_get_kps(&path, n_keypoints).unwrap();
        assert_eq!(descs.len(), n_keypoints);
        let vocab = Vocabulary::create(&descs, 3, 3);
        println!("vocab:{:?}", vocab);

        // The writer emits one `nodes` entry per Block child and one `words` entry per Leaf child.
        let child_ids = || vocab.blocks.iter().flat_map(|b| b.children.ids.iter());
        let expected_blocks = child_ids().filter(|n| matches!(n, NodeId::Block(_))).count();
        let expected_leaves = child_ids().filter(|n| matches!(n, NodeId::Leaf(_))).count();

        // The OS temp dir always exists. A per-process file name avoids collisions
        // between concurrent or repeated runs, and each file is removed after loading.
        let tmp_dir = std::env::temp_dir();
        for ext in ["yml", "yml.gz"] {
            let file = tmp_dir.join(format!("vocab_dbow3_{}.{}", std::process::id(), ext));
            vocab.save_dbow3(&file).expect("save_dbow3 failed");
            let loaded = Vocabulary::load_dbow3(&file);
            let _ = std::fs::remove_file(&file);
            let loaded = loaded.expect("load_dbow3 failed");
            assert_eq!(loaded.k, vocab.k);
            assert_eq!(loaded.levels, vocab.levels);
            assert_eq!(loaded.num_blocks, expected_blocks);
            assert_eq!(loaded.num_leaves, expected_leaves);
        }
    }
}