1use super::State;
2use crate::error::Error;
3use crate::{
4 Action, CallbackType, ConnectionAction, ConnectionCallbackType, EntryData, EntryValue,
5 RpcCallback,
6};
7use futures_channel::mpsc::{channel, unbounded, Receiver, Sender, UnboundedSender};
8use futures_util::StreamExt;
9use multimap::MultiMap;
10use nt_network::{
11 ClearAllEntries, EntryAssignment, EntryDelete, EntryFlagsUpdate, EntryUpdate, Packet,
12 RpcExecute,
13};
14use std::collections::HashMap;
15use std::net::SocketAddr;
16use std::sync::{Arc, Mutex};
17use std::thread;
18use tokio::runtime::Runtime;
19
20pub(crate) mod conn;
21
/// Client-side state: the locally known entries, registered callbacks, and the
/// outgoing packet queue used to talk to the connection task.
///
/// Instances are created by the constructors on this type, which wrap the
/// state in an `Arc<Mutex<_>>` and share it with a background connection thread.
pub struct ClientState {
    /// Whether the client is currently considered connected. Starts out
    /// `false`; `create_entry` refuses to run while it is `false`.
    /// NOTE(review): presumably flipped by the `conn` module — confirm there.
    pub(crate) connected: bool,
    /// Address (or, for the websocket constructor, URL) supplied at construction.
    ip: String,
    /// Client name supplied at construction.
    name: String,
    /// Known entries, keyed by their `u16` entry id.
    entries: HashMap<u16, EntryData>,
    /// Entry callbacks registered through `add_callback`, grouped by type.
    callbacks: MultiMap<CallbackType, Box<Action>>,
    /// Connection callbacks registered through `add_connection_callback`,
    /// grouped by type.
    connection_callbacks: MultiMap<ConnectionCallbackType, Box<ConnectionAction>>,
    /// Entries requested via `create_entry` that are still awaiting an id,
    /// keyed by entry name. The matching receiver is handed to the caller.
    /// NOTE(review): presumably resolved by the `conn` module once the server
    /// assigns an id — confirm there.
    pub(crate) pending_entries: HashMap<String, Sender<u16>>,
    /// Sending half of the outgoing packet queue; the receiving half is given
    /// to the connection task at construction.
    pub(crate) packet_tx: UnboundedSender<Box<dyn Packet>>,
    /// Callbacks for in-flight RPC calls, keyed by the per-call id generated
    /// from `next_rpc_id`.
    rpc_callbacks: HashMap<u16, Box<RpcCallback>>,
    /// Id that the next `call_rpc` invocation will use.
    next_rpc_id: u16,
}
34
35impl ClientState {
36 pub async fn new(ip: String, name: String, close_rx: Receiver<()>) -> Arc<Mutex<ClientState>> {
37 let (packet_tx, packet_rx) = unbounded::<Box<dyn Packet>>();
38 let (ready_tx, mut ready_rx) = unbounded::<()>();
39
40 let state = Arc::new(Mutex::new(ClientState {
41 connected: false,
42 ip,
43 name,
44 entries: HashMap::new(),
45 callbacks: MultiMap::new(),
46 connection_callbacks: MultiMap::new(),
47 pending_entries: HashMap::new(),
48 packet_tx,
49 rpc_callbacks: HashMap::new(),
50 next_rpc_id: 0,
51 }));
52
53 let rt_state = state.clone();
54 thread::spawn(move || {
55 let mut rt = Runtime::new().unwrap();
56 rt.block_on(conn::connection(rt_state, packet_rx, ready_tx, close_rx))
57 .unwrap();
58 });
59
60 ready_rx.next().await;
61 state
62 }
63
64 #[cfg(feature = "websocket")]
65 pub async fn new_ws(
66 url: String,
67 name: String,
68 close_rx: Receiver<()>,
69 ) -> crate::Result<Arc<Mutex<ClientState>>> {
70 let (packet_tx, packet_rx) = unbounded::<Box<dyn Packet>>();
71 let (ready_tx, mut ready_rx) = unbounded::<()>();
72
73 let state = Arc::new(Mutex::new(ClientState {
74 connected: false,
75 ip: url,
76 name,
77 entries: HashMap::new(),
78 callbacks: MultiMap::new(),
79 connection_callbacks: MultiMap::new(),
80 pending_entries: HashMap::new(),
81 packet_tx,
82 rpc_callbacks: HashMap::new(),
83 next_rpc_id: 0,
84 }));
85
86 let rt_state = state.clone();
87 thread::spawn(move || {
88 let mut rt = Runtime::new().unwrap();
89
90 let _ = rt.block_on(conn::connection_ws(rt_state, packet_rx, ready_tx, close_rx));
91 });
92
93 if let None = ready_rx.next().await {
94 return Err(Error::ConnectionAborted);
95 }
96 Ok(state)
97 }
98
99 pub fn add_connection_callback(
100 &mut self,
101 callback_type: ConnectionCallbackType,
102 action: impl FnMut(&SocketAddr) + Send + 'static,
103 ) {
104 self.connection_callbacks
105 .insert(callback_type, Box::new(action));
106 }
107
108 pub fn call_rpc(
109 &mut self,
110 id: u16,
111 parameter: Vec<u8>,
112 callback: impl Fn(Vec<u8>) + Send + 'static,
113 ) {
114 self.rpc_callbacks
115 .insert(self.next_rpc_id, Box::new(callback));
116 self.packet_tx
117 .unbounded_send(Box::new(RpcExecute::new(id, self.next_rpc_id, parameter)))
118 .unwrap();
119
120 self.next_rpc_id += 1;
121 }
122}
123
124impl State for ClientState {
125 fn entries(&self) -> &HashMap<u16, EntryData> {
126 &self.entries
127 }
128
129 fn entries_mut(&mut self) -> &mut HashMap<u16, EntryData> {
130 &mut self.entries
131 }
132
133 fn create_entry(&mut self, data: EntryData) -> crate::Result<Receiver<u16>> {
134 if !self.connected {
135 return Err(Error::BrokenPipe);
136 }
137 let (tx, rx) = channel::<u16>(1);
138 self.pending_entries.insert(data.name.clone(), tx);
139 self.packet_tx
140 .unbounded_send(Box::new(EntryAssignment::new(
141 data.name.clone(),
142 data.entry_type(),
143 0xFFFF,
144 data.seqnum,
145 data.flags,
146 data.value,
147 )))
148 .unwrap();
149 Ok(rx)
150 }
151
152 fn delete_entry(&mut self, id: u16) {
153 let packet = EntryDelete::new(id);
154 self.packet_tx.unbounded_send(Box::new(packet)).unwrap();
155 }
156
157 fn update_entry(&mut self, id: u16, new_value: EntryValue) {
158 if let Some(entry) = self.entries.get_mut(&id) {
159 entry.value = new_value.clone();
160 entry.seqnum += 1;
161 self.packet_tx
162 .unbounded_send(Box::new(EntryUpdate::new(
163 id,
164 entry.seqnum,
165 entry.entry_type(),
166 new_value,
167 )))
168 .unwrap();
169 }
170 }
171
172 fn update_entry_flags(&mut self, id: u16, flags: u8) {
173 if let Some(entry) = self.entries.get_mut(&id) {
174 entry.flags = flags;
175 self.packet_tx
176 .unbounded_send(Box::new(EntryFlagsUpdate::new(id, flags)))
177 .unwrap();
178 }
179 }
180
181 fn clear_entries(&mut self) {
182 self.packet_tx
183 .unbounded_send(Box::new(ClearAllEntries::new()))
184 .unwrap();
185 self.entries.clear();
186 }
187
188 fn add_callback(
189 &mut self,
190 callback_type: CallbackType,
191 action: impl FnMut(&EntryData) + Send + 'static,
192 ) {
193 self.callbacks.insert(callback_type, Box::new(action));
194 }
195}