// mysql/conn/stmt_cache.rs

// Copyright (c) 2020 rust-mysql-simple contributors
//
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.

9use lru::LruCache;
10use twox_hash::XxHash;
11
12use std::{
13    borrow::Borrow,
14    collections::HashMap,
15    hash::{BuildHasherDefault, Hash},
16    sync::Arc,
17};
18
19use crate::conn::stmt::InnerStmt;
20
/// Query text of a cached statement, shared through an `Arc` so the cache can
/// keep it both in its query index and in the entry without copying the bytes.
///
/// Hashing, equality and ordering all delegate to the underlying bytes. That
/// consistency is what makes the `Borrow<[u8]>` impl below valid: a map keyed
/// by `QueryString` can be queried with a plain `&[u8]`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct QueryString(pub Arc<Vec<u8>>);

impl Borrow<[u8]> for QueryString {
    /// Views the query text as a byte slice.
    fn borrow(&self) -> &[u8] {
        self.0.as_slice()
    }
}

impl PartialEq<[u8]> for QueryString {
    /// Compares the query text byte-for-byte against `other`.
    fn eq(&self, other: &[u8]) -> bool {
        &self.0[..] == other
    }
}
35
36pub struct Entry {
37    pub stmt: Arc<InnerStmt>,
38    pub query: QueryString,
39}
40
41#[derive(Debug)]
42pub struct StmtCache {
43    cap: usize,
44    cache: LruCache<u32, Entry>,
45    query_map: HashMap<QueryString, u32, BuildHasherDefault<XxHash>>,
46}
47
48impl StmtCache {
49    pub fn new(cap: usize) -> StmtCache {
50        StmtCache {
51            cap,
52            cache: LruCache::unbounded(),
53            query_map: Default::default(),
54        }
55    }
56
57    pub fn contains_query<T>(&self, key: &T) -> bool
58    where
59        QueryString: Borrow<T>,
60        T: Hash + Eq,
61        T: ?Sized,
62    {
63        self.query_map.contains_key(key)
64    }
65
66    pub fn by_query<T>(&mut self, query: &T) -> Option<&Entry>
67    where
68        QueryString: Borrow<T>,
69        QueryString: PartialEq<T>,
70        T: Hash + Eq,
71        T: ?Sized,
72    {
73        let id = self.query_map.get(query).cloned();
74        match id {
75            Some(id) => self.cache.get(&id),
76            None => None,
77        }
78    }
79
80    pub fn put(&mut self, query: Arc<Vec<u8>>, stmt: Arc<InnerStmt>) -> Option<Arc<InnerStmt>> {
81        if self.cap == 0 {
82            return None;
83        }
84
85        let query = QueryString(query);
86
87        self.query_map.insert(query.clone(), stmt.id());
88        self.cache.put(stmt.id(), Entry { stmt, query });
89
90        if self.cache.len() > self.cap {
91            if let Some((_, entry)) = self.cache.pop_lru() {
92                self.query_map.remove(&**entry.query.0.as_ref());
93                return Some(entry.stmt);
94            }
95        }
96
97        None
98    }
99
100    pub fn clear(&mut self) {
101        self.query_map.clear();
102        self.cache.clear();
103    }
104
105    pub fn remove(&mut self, id: u32) {
106        if let Some(entry) = self.cache.pop(&id) {
107            self.query_map.remove::<[u8]>(entry.query.borrow());
108        }
109    }
110
111    #[cfg(test)]
112    pub fn iter(&self) -> impl Iterator<Item = (&u32, &Entry)> {
113        self.cache.iter()
114    }
115
116    pub fn into_iter(mut self) -> impl Iterator<Item = (u32, Entry)> {
117        std::iter::from_fn(move || self.cache.pop_lru())
118    }
119}