// libsignal_rust/session_record.rs

1use crate::base_key_type::BaseKeyType;
2use crate::chain_type::ChainType;
3use std::collections::HashMap;
4use serde::{Serialize, Deserialize};
5use base64::{Engine as _, engine::general_purpose};
6
/// Upper bound on the number of sessions a [`SessionRecord`] keeps.
///
/// When the record holds more than this many sessions,
/// [`SessionRecord::remove_old_sessions`] evicts *closed* sessions, oldest
/// close time first; open sessions are never evicted.
const CLOSED_SESSIONS_MAX: usize = 40;

/// Version tag assigned to every record created by [`SessionRecord::new`].
const SESSION_RECORD_VERSION: &str = "v1";
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ChainInfo {
12    pub message_keys: HashMap<u32, Vec<u8>>,
13    pub chain_key: ChainKey,
14    pub chain_type: ChainType,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ChainKey {
19    pub counter: i32,
20    pub key: Option<Vec<u8>>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct CurrentRatchet {
25    pub ephemeral_key_pair: crate::curve::KeyPair,
26    pub last_remote_ephemeral_key: Vec<u8>,
27    pub previous_counter: u32,
28    pub root_key: Vec<u8>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct IndexInfo {
33    pub created: u64,
34    pub used: u64,
35    pub remote_identity_key: Vec<u8>,
36    pub base_key: Vec<u8>,
37    pub base_key_type: BaseKeyType,
38    pub closed: i64,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct PendingPreKey {
43    pub signed_key_id: u32,
44    pub base_key: Vec<u8>,
45    pub pre_key_id: Option<u32>,
46}
47
48#[derive(Debug, Clone)]
49pub struct SessionEntry {
50    pub registration_id: u32,
51    pub current_ratchet: CurrentRatchet,
52    pub index_info: IndexInfo,
53    pub pending_pre_key: Option<PendingPreKey>,
54    chains: HashMap<String, ChainInfo>,
55}
56
57impl SessionEntry {
58    pub fn new() -> Self {
59        Self {
60            registration_id: 0,
61            current_ratchet: CurrentRatchet {
62                ephemeral_key_pair: crate::curve::KeyPair {
63                    priv_key: vec![],
64                    pub_key: vec![],
65                },
66                last_remote_ephemeral_key: vec![],
67                previous_counter: 0,
68                root_key: vec![],
69            },
70            index_info: IndexInfo {
71                created: 0,
72                used: 0,
73                remote_identity_key: vec![],
74                base_key: vec![],
75                base_key_type: BaseKeyType::Ours,
76                closed: -1,
77            },
78            pending_pre_key: None,
79            chains: HashMap::new(),
80        }
81    }
82
83    pub fn add_chain(&mut self, key: &[u8], value: ChainInfo) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
84        let id = general_purpose::STANDARD.encode(key);
85        if self.chains.contains_key(&id) {
86            return Err("Overwrite attempt".into());
87        }
88        self.chains.insert(id, value);
89        Ok(())
90    }
91
92    pub fn get_chain(&self, key: &[u8]) -> Option<&ChainInfo> {
93        let id = general_purpose::STANDARD.encode(key);
94        self.chains.get(&id)
95    }
96
97    pub fn get_chain_mut(&mut self, key: &[u8]) -> Option<&mut ChainInfo> {
98        let id = general_purpose::STANDARD.encode(key);
99        self.chains.get_mut(&id)
100    }
101
102    pub fn delete_chain(&mut self, key: &[u8]) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
103        let id = general_purpose::STANDARD.encode(key);
104        if !self.chains.contains_key(&id) {
105            return Err("Not Found".into());
106        }
107        self.chains.remove(&id);
108        Ok(())
109    }
110
111    pub fn chains(&self) -> impl Iterator<Item = (Vec<u8>, &ChainInfo)> {
112        self.chains.iter().map(|(k, v)| {
113            let key = general_purpose::STANDARD.decode(k).unwrap_or_default();
114            (key, v)
115        })
116    }
117}
118
119#[derive(Debug, Clone)]
120pub struct SessionRecord {
121    pub sessions: HashMap<String, SessionEntry>,
122    pub version: String,
123}
124
125impl SessionRecord {
126    pub fn new() -> Self {
127        Self {
128            sessions: HashMap::new(),
129            version: SESSION_RECORD_VERSION.to_string(),
130        }
131    }
132
133    pub fn create_entry() -> SessionEntry {
134        SessionEntry::new()
135    }
136
137    pub fn have_open_session(&self) -> bool {
138        if let Some(open_session) = self.get_open_session() {
139            open_session.registration_id > 0
140        } else {
141            false
142        }
143    }
144
145    pub fn get_session(&self, key: &[u8]) -> Option<&SessionEntry> {
146        let id = general_purpose::STANDARD.encode(key);
147        let session = self.sessions.get(&id);
148        if let Some(session) = session {
149            if session.index_info.base_key_type == BaseKeyType::Ours {
150                // Invalid operation cuzz cannot lookup session using our own base key
151                return None;
152            }
153        }
154        session
155    }
156
157    pub fn get_open_session(&self) -> Option<&SessionEntry> {
158        for session in self.sessions.values() {
159            if !self.is_closed(session) {
160                return Some(session);
161            }
162        }
163        None
164    }
165
166    pub fn set_session(&mut self, session: SessionEntry) {
167        let id = general_purpose::STANDARD.encode(&session.index_info.base_key);
168        self.sessions.insert(id, session);
169    }
170
171    pub fn get_sessions(&self) -> Vec<&SessionEntry> {
172        let mut sessions: Vec<&SessionEntry> = self.sessions.values().collect();
173        sessions.sort_by(|a, b| {
174            let a_used = a.index_info.used;
175            let b_used = b.index_info.used;
176            b_used.cmp(&a_used)
177        });
178        sessions
179    }
180
181    pub fn close_session(&mut self, session_key: &[u8]) {
182        let id = general_purpose::STANDARD.encode(session_key);
183        if let Some(session) = self.sessions.get_mut(&id) {
184            if session.index_info.closed != -1 {
185                return;
186            }
187            session.index_info.closed = chrono::Utc::now().timestamp();
188        }
189    }
190
191    pub fn open_session(&mut self, session_key: &[u8]) {
192        let id = general_purpose::STANDARD.encode(session_key);
193        if let Some(session) = self.sessions.get_mut(&id) {
194            session.index_info.closed = -1;
195        }
196    }
197
198    pub fn is_closed(&self, session: &SessionEntry) -> bool {
199        session.index_info.closed != -1
200    }
201
202    pub fn remove_old_sessions(&mut self) {
203        while self.sessions.len() > CLOSED_SESSIONS_MAX {
204            let mut oldest_key: Option<String> = None;
205            let mut oldest_closed: i64 = i64::MAX;
206
207            for (key, session) in &self.sessions {
208                if session.index_info.closed != -1 && session.index_info.closed < oldest_closed {
209                    oldest_key = Some(key.clone());
210                    oldest_closed = session.index_info.closed;
211                }
212            }
213
214            if let Some(key) = oldest_key {
215                self.sessions.remove(&key);
216            } else {
217                break;
218            }
219        }
220    }
221
222    pub fn delete_all_sessions(&mut self) {
223        self.sessions.clear();
224    }
225}