use std::{collections::HashMap, path::PathBuf};
use anyhow::Result;
use itertools::Itertools;
use rusqlite::{Connection as RusqliteConnection, named_params, params_from_iter};
use tokio_rusqlite_new::Connection;
use crate::vrchat::VRCHAT_LOW_PATH;
/// Cache of avatar IDs and the `u32` provider bits they were last stored with,
/// kept in an SQLite table named `avatars`.
///
/// All database work is dispatched through the connection's `call` closure API.
pub struct Cache {
    /// Async handle to the underlying SQLite connection.
    connection: Connection,
}
/// An avatar ID (any `S`; stored via its `ToString` form) paired with the
/// `u32` value written to the `provider_bits` column.
pub type AvatarIDWithProvider<S> = (S, u32);
impl Cache {
pub async fn new() -> Result<Self> {
debug!("Trying to open SQLite cache database.");
Self::new_at_location(&VRCHAT_LOW_PATH.join("avatars.sqlite")).await
}
pub async fn new_at_location(path: &PathBuf) -> Result<Self> {
debug!("Trying to open SQLite cache database.");
let connection = Connection::open(path).await?;
connection
.call(|connection| Self::setup_database(connection))
.await?;
Ok(Self { connection })
}
pub async fn new_in_memory() -> Result<Self> {
let connection = Connection::open_in_memory().await?;
connection
.call(|connection| Self::setup_database(connection))
.await?;
Ok(Self { connection })
}
fn setup_database(connection: &RusqliteConnection) -> Result<(), rusqlite::Error> {
let query = "CREATE TABLE avatars (
id TEXT PRIMARY KEY,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
provider_bits INT DEFAULT 0
)";
debug!("Trying to create avatars table...");
if connection.execute(query, []).is_err() {
debug!("The avatars table already exists.");
let mut statement = connection.prepare("PRAGMA table_info(avatars)")?;
let columns = statement
.query_map([], |row| row.get::<_, String>(1))?
.collect::<Result<Vec<_>, _>>()?;
if !columns.contains(&"created_at".to_string()) {
debug!("Trying to create the created_at column.");
#[rustfmt::skip]
connection.execute("
ALTER TABLE avatars
ADD COLUMN created_at DATETIME
", [])?;
}
debug!("Updating all rows with missing created_at");
#[rustfmt::skip]
connection.execute("
UPDATE avatars
SET created_at = CURRENT_TIMESTAMP
WHERE created_at IS NULL
", [])?;
if !columns.contains(&"updated_at".to_string()) {
debug!("Trying to create the updated_at column.");
#[rustfmt::skip]
connection.execute("
ALTER TABLE avatars
ADD COLUMN updated_at DATETIME
", [])?;
}
if !columns.contains(&"provider_bits".to_string()) {
debug!("Trying to create the provider_bits column.");
#[rustfmt::skip]
connection.execute("
ALTER TABLE avatars
ADD COLUMN provider_bits INT DEFAULT 0
", [])?;
}
debug!("Updating all rows with missing updated_at");
#[rustfmt::skip]
connection.execute("
UPDATE avatars
SET updated_at = datetime('now', '-31 days')
WHERE updated_at IS NULL
", [])?;
}
debug!("Trying to create an updated_at index.");
#[rustfmt::skip]
connection.execute("
CREATE INDEX IF NOT EXISTS idx_avatars_updated_at
ON avatars(updated_at)
", [])?;
debug!("Trying to create an id index.");
#[rustfmt::skip]
connection.execute(
"
CREATE INDEX IF NOT EXISTS idx_avatars_id
ON avatars(id)
", [])?;
if let Ok(mut statement) = connection.prepare("SELECT COUNT(*) FROM avatars")
&& let Ok(count) = statement.query_row([], |row| row.get::<_, i64>(0))
{
info!("{} Cached Avatars", count);
}
Ok(())
}
pub async fn store_avatar_ids_with_providers<
S: ToString,
I: IntoIterator<Item = AvatarIDWithProvider<S>>,
>(
&self,
insertables: I,
) -> Result<()> {
let query = "
INSERT INTO avatars (id, provider_bits, created_at, updated_at)
VALUES (:id, :providers, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
ON CONFLICT (id) DO UPDATE
SET updated_at = CURRENT_TIMESTAMP,
provider_bits = :providers
";
let insertables: Vec<_> = insertables
.into_iter()
.map(|a| (a.0.to_string(), a.1))
.collect();
self.connection
.call(|c| -> Result<(), rusqlite::Error> {
let tx = c.transaction()?;
for (id, providers) in insertables {
let _ = tx.execute(
query,
named_params! {
":id": id,
":providers": providers
},
);
}
tx.commit()
})
.await
.map_err(anyhow::Error::from)
}
const CHUNK_SIZE: usize = 950;
pub async fn check_all_ids<I: IntoIterator<Item = String>>(
&self,
ids: I,
) -> Result<HashMap<String, u32>> {
let ids: Vec<_> = ids.into_iter().collect();
self.connection
.call(|c| -> Result<_, rusqlite::Error> {
let mut output = HashMap::new();
for chunk in &ids.into_iter().chunks(Self::CHUNK_SIZE) {
let chunk: Vec<String> = chunk.collect();
for id in &chunk {
output.insert(id.clone(), 0);
}
let found_ids = Self::check_batch_ids(c, chunk.into_iter())?;
output.extend(found_ids);
}
Ok(output)
})
.await
.map_err(|e| anyhow::anyhow!(e))
}
fn check_batch_ids<I: Iterator<Item = String>>(
conn: &RusqliteConnection,
chunk: I,
) -> std::result::Result<HashMap<String, u32>, rusqlite::Error> {
let chunk: Vec<_> = chunk.collect();
assert!(chunk.len() <= Self::CHUNK_SIZE);
let placeholders = std::iter::repeat_n("?", chunk.len())
.collect::<Vec<_>>()
.join(",");
let sql = format!(
"SELECT id, provider_bits FROM avatars WHERE id IN ({placeholders}) AND updated_at >= datetime('now', '-30 days')"
);
let mut stmt = conn.prepare(&sql)?;
stmt.query_map(params_from_iter(chunk.iter()), |row| {
Ok((row.get::<_, String>(0)?, row.get::<_, u32>(1)?))
})?
.collect::<Result<HashMap<_, _>, _>>()
}
}
#[cfg(test)]
mod tests {
    use super::{Cache, Connection};

    /// Fresh, fully set-up in-memory cache for a single test.
    async fn cache() -> Cache {
        Cache::new_in_memory().await.unwrap()
    }

    #[tokio::test]
    async fn creates_database_and_table() {
        let cache = cache().await;
        cache
            .store_avatar_ids_with_providers(vec![("avatar_1", 1u32)])
            .await
            .unwrap();
        let result = cache.check_all_ids(vec!["avatar_1".into()]).await.unwrap();
        assert_eq!(result["avatar_1"], 1);
    }

    #[tokio::test]
    async fn inserts_and_reads_avatar_ids() {
        let cache = cache().await;
        cache
            .store_avatar_ids_with_providers(vec![("avatar_a", 1u32), ("avatar_b", 2u32)])
            .await
            .unwrap();
        let result = cache
            .check_all_ids(vec!["avatar_a".into(), "avatar_b".into()])
            .await
            .unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result["avatar_a"], 1);
        assert_eq!(result["avatar_b"], 2);
    }

    #[tokio::test]
    async fn returns_zero_for_missing_ids() {
        let cache = cache().await;
        let result = cache
            .check_all_ids(vec!["missing_avatar".into()])
            .await
            .unwrap();
        assert_eq!(result["missing_avatar"], 0);
    }

    #[tokio::test]
    async fn handles_empty_input() {
        let cache = cache().await;
        cache
            .store_avatar_ids_with_providers(Vec::<(&str, u32)>::new())
            .await
            .unwrap();
        let result = cache.check_all_ids(Vec::new()).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn updates_provider_bits_on_conflict() {
        let cache = cache().await;
        cache
            .store_avatar_ids_with_providers(vec![("avatar_x", 1u32)])
            .await
            .unwrap();
        cache
            .store_avatar_ids_with_providers(vec![("avatar_x", 42u32)])
            .await
            .unwrap();
        let result = cache.check_all_ids(vec!["avatar_x".into()]).await.unwrap();
        assert_eq!(result["avatar_x"], 42);
    }

    #[tokio::test]
    async fn respects_chunking_limits() {
        let cache = cache().await;
        let ids: Vec<(String, u32)> = (0..(Cache::CHUNK_SIZE + 10))
            .map(|i| {
                let provider = u32::try_from(i).expect("test ids fit in u32");
                (format!("avatar_{i}"), provider)
            })
            .collect();
        cache
            .store_avatar_ids_with_providers(ids.iter().map(|(id, p)| (id.as_str(), *p)))
            .await
            .unwrap();
        let result = cache
            .check_all_ids(ids.iter().map(|(id, _)| id.clone()))
            .await
            .unwrap();
        assert_eq!(result.len(), ids.len());
        for (id, provider) in ids {
            assert_eq!(result[&id], provider);
        }
    }

    #[tokio::test]
    async fn ignores_entries_older_than_30_days() {
        let cache = cache().await;
        cache
            .store_avatar_ids_with_providers(vec![("old_avatar", 1u32)])
            .await
            .unwrap();
        cache
            .connection
            .call(|c| {
                c.execute(
                    "UPDATE avatars
                     SET updated_at = datetime('now', '-31 days')
                     WHERE id = 'old_avatar'",
                    [],
                )
            })
            .await
            .unwrap();
        let result = cache
            .check_all_ids(vec!["old_avatar".into()])
            .await
            .unwrap();
        assert_eq!(result["old_avatar"], 0);
    }

    #[tokio::test]
    async fn setup_is_idempotent() {
        let cache = cache().await;
        cache
            .store_avatar_ids_with_providers(vec![("avatar_keep", 3u32)])
            .await
            .unwrap();
        // Running setup again on an initialised database must neither fail
        // nor disturb existing rows.
        cache
            .connection
            .call(|c| Cache::setup_database(c))
            .await
            .unwrap();
        let result = cache
            .check_all_ids(vec!["avatar_keep".into()])
            .await
            .unwrap();
        assert_eq!(result["avatar_keep"], 3);
    }

    #[tokio::test]
    async fn migrates_legacy_table() {
        let connection = Connection::open_in_memory().await.unwrap();
        connection
            .call(|c| -> Result<(), rusqlite::Error> {
                // Oldest schema: only the primary key column.
                c.execute("CREATE TABLE avatars (id TEXT PRIMARY KEY)", [])?;
                c.execute("INSERT INTO avatars (id) VALUES ('legacy')", [])?;
                Cache::setup_database(c)
            })
            .await
            .unwrap();
        let cache = Cache { connection };

        // Backfilled rows are dated outside the freshness window.
        let result = cache.check_all_ids(vec!["legacy".into()]).await.unwrap();
        assert_eq!(result["legacy"], 0);

        // The migrated table accepts upserts on the existing row.
        cache
            .store_avatar_ids_with_providers(vec![("legacy", 5u32)])
            .await
            .unwrap();
        let result = cache.check_all_ids(vec!["legacy".into()]).await.unwrap();
        assert_eq!(result["legacy"], 5);
    }
}