// libsignal_rust/session_record.rs
use std::collections::hash_map::Entry;
use std::collections::HashMap;

use base64::{engine::general_purpose, Engine as _};
use serde::{Deserialize, Serialize};

use crate::base_key_type::BaseKeyType;
use crate::chain_type::ChainType;

/// Session-count limit enforced by [`SessionRecord::remove_old_sessions`].
///
/// Only closed sessions are ever evicted to get under this limit.
const CLOSED_SESSIONS_MAX: usize = 40;

/// Version tag stored in every record created by [`SessionRecord::new`].
const SESSION_RECORD_VERSION: &str = "v1";

10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ChainInfo {
12 pub message_keys: HashMap<u32, Vec<u8>>,
13 pub chain_key: ChainKey,
14 pub chain_type: ChainType,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ChainKey {
19 pub counter: i32,
20 pub key: Option<Vec<u8>>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct CurrentRatchet {
25 pub ephemeral_key_pair: crate::curve::KeyPair,
26 pub last_remote_ephemeral_key: Vec<u8>,
27 pub previous_counter: u32,
28 pub root_key: Vec<u8>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct IndexInfo {
33 pub created: u64,
34 pub used: u64,
35 pub remote_identity_key: Vec<u8>,
36 pub base_key: Vec<u8>,
37 pub base_key_type: BaseKeyType,
38 pub closed: i64,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct PendingPreKey {
43 pub signed_key_id: u32,
44 pub base_key: Vec<u8>,
45 pub pre_key_id: Option<u32>,
46}
47
48#[derive(Debug, Clone)]
49pub struct SessionEntry {
50 pub registration_id: u32,
51 pub current_ratchet: CurrentRatchet,
52 pub index_info: IndexInfo,
53 pub pending_pre_key: Option<PendingPreKey>,
54 chains: HashMap<String, ChainInfo>,
55}
56
57impl SessionEntry {
58 pub fn new() -> Self {
59 Self {
60 registration_id: 0,
61 current_ratchet: CurrentRatchet {
62 ephemeral_key_pair: crate::curve::KeyPair {
63 priv_key: vec![],
64 pub_key: vec![],
65 },
66 last_remote_ephemeral_key: vec![],
67 previous_counter: 0,
68 root_key: vec![],
69 },
70 index_info: IndexInfo {
71 created: 0,
72 used: 0,
73 remote_identity_key: vec![],
74 base_key: vec![],
75 base_key_type: BaseKeyType::Ours,
76 closed: -1,
77 },
78 pending_pre_key: None,
79 chains: HashMap::new(),
80 }
81 }
82
83 pub fn add_chain(&mut self, key: &[u8], value: ChainInfo) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
84 let id = general_purpose::STANDARD.encode(key);
85 if self.chains.contains_key(&id) {
86 return Err("Overwrite attempt".into());
87 }
88 self.chains.insert(id, value);
89 Ok(())
90 }
91
92 pub fn get_chain(&self, key: &[u8]) -> Option<&ChainInfo> {
93 let id = general_purpose::STANDARD.encode(key);
94 self.chains.get(&id)
95 }
96
97 pub fn get_chain_mut(&mut self, key: &[u8]) -> Option<&mut ChainInfo> {
98 let id = general_purpose::STANDARD.encode(key);
99 self.chains.get_mut(&id)
100 }
101
102 pub fn delete_chain(&mut self, key: &[u8]) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
103 let id = general_purpose::STANDARD.encode(key);
104 if !self.chains.contains_key(&id) {
105 return Err("Not Found".into());
106 }
107 self.chains.remove(&id);
108 Ok(())
109 }
110
111 pub fn chains(&self) -> impl Iterator<Item = (Vec<u8>, &ChainInfo)> {
112 self.chains.iter().map(|(k, v)| {
113 let key = general_purpose::STANDARD.decode(k).unwrap_or_default();
114 (key, v)
115 })
116 }
117}
118
119#[derive(Debug, Clone)]
120pub struct SessionRecord {
121 pub sessions: HashMap<String, SessionEntry>,
122 pub version: String,
123}
124
125impl SessionRecord {
126 pub fn new() -> Self {
127 Self {
128 sessions: HashMap::new(),
129 version: SESSION_RECORD_VERSION.to_string(),
130 }
131 }
132
133 pub fn create_entry() -> SessionEntry {
134 SessionEntry::new()
135 }
136
137 pub fn have_open_session(&self) -> bool {
138 if let Some(open_session) = self.get_open_session() {
139 open_session.registration_id > 0
140 } else {
141 false
142 }
143 }
144
145 pub fn get_session(&self, key: &[u8]) -> Option<&SessionEntry> {
146 let id = general_purpose::STANDARD.encode(key);
147 let session = self.sessions.get(&id);
148 if let Some(session) = session {
149 if session.index_info.base_key_type == BaseKeyType::Ours {
150 return None;
152 }
153 }
154 session
155 }
156
157 pub fn get_open_session(&self) -> Option<&SessionEntry> {
158 for session in self.sessions.values() {
159 if !self.is_closed(session) {
160 return Some(session);
161 }
162 }
163 None
164 }
165
166 pub fn set_session(&mut self, session: SessionEntry) {
167 let id = general_purpose::STANDARD.encode(&session.index_info.base_key);
168 self.sessions.insert(id, session);
169 }
170
171 pub fn get_sessions(&self) -> Vec<&SessionEntry> {
172 let mut sessions: Vec<&SessionEntry> = self.sessions.values().collect();
173 sessions.sort_by(|a, b| {
174 let a_used = a.index_info.used;
175 let b_used = b.index_info.used;
176 b_used.cmp(&a_used)
177 });
178 sessions
179 }
180
181 pub fn close_session(&mut self, session_key: &[u8]) {
182 let id = general_purpose::STANDARD.encode(session_key);
183 if let Some(session) = self.sessions.get_mut(&id) {
184 if session.index_info.closed != -1 {
185 return;
186 }
187 session.index_info.closed = chrono::Utc::now().timestamp();
188 }
189 }
190
191 pub fn open_session(&mut self, session_key: &[u8]) {
192 let id = general_purpose::STANDARD.encode(session_key);
193 if let Some(session) = self.sessions.get_mut(&id) {
194 session.index_info.closed = -1;
195 }
196 }
197
198 pub fn is_closed(&self, session: &SessionEntry) -> bool {
199 session.index_info.closed != -1
200 }
201
202 pub fn remove_old_sessions(&mut self) {
203 while self.sessions.len() > CLOSED_SESSIONS_MAX {
204 let mut oldest_key: Option<String> = None;
205 let mut oldest_closed: i64 = i64::MAX;
206
207 for (key, session) in &self.sessions {
208 if session.index_info.closed != -1 && session.index_info.closed < oldest_closed {
209 oldest_key = Some(key.clone());
210 oldest_closed = session.index_info.closed;
211 }
212 }
213
214 if let Some(key) = oldest_key {
215 self.sessions.remove(&key);
216 } else {
217 break;
218 }
219 }
220 }
221
222 pub fn delete_all_sessions(&mut self) {
223 self.sessions.clear();
224 }
225}