1use crate::auth::{AuthContext, AuthCredentials};
4use crate::connection::{Connection, ConnectionState};
5use crate::error::{CommyError, Result};
6use crate::message::{ClientMessage, ServerMessage};
7use crate::service::Service;
8use crate::state::{create_shared_state, SharedState};
9use crate::virtual_file::VirtualVariableFile;
10use crate::watcher::VariableFileWatcher;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::RwLock;
15use uuid::Uuid;
16
/// Asynchronous client that talks to a server at `server_url` over a [`Connection`].
///
/// All mutable state sits behind `Arc`-wrapped tokio locks or atomics, so every
/// method on the client takes `&self`.
pub struct Client {
    /// Identifier of this client: a random UUID v4 from `new`, or caller-supplied via
    /// `with_id`. Sent in `Heartbeat` and `Disconnect` messages.
    client_id: String,

    /// URL handed to `Connection::new` on every (re)connect.
    server_url: String,

    /// Active connection; `None` before the first successful connect and after `disconnect`.
    connection: Arc<RwLock<Option<Connection>>>,

    /// Shared session state (connection state, session id, authenticated tenants,
    /// last-activity timestamp).
    state: SharedState,

    /// Tick period of the background heartbeat task.
    heartbeat_interval: Duration,

    /// Upper bound on automatic reconnects performed when a send reports a lost connection.
    max_reconnect_attempts: u32,

    /// Reconnects attempted since the last successful connect (reset to 0 on success).
    reconnect_attempts: Arc<AtomicU64>,

    /// Cache of virtual variable files, keyed by `"{tenant_id}_{service_name}"`.
    virtual_files: Arc<RwLock<std::collections::HashMap<String, Arc<VirtualVariableFile>>>>,

    /// File watcher; `None` until monitoring is started and after it is stopped.
    file_watcher: Arc<RwLock<Option<Arc<VariableFileWatcher>>>>,

    /// Handle of the spawned heartbeat task, kept so the task can be aborted.
    heartbeat_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
}
49
50impl Client {
51 pub async fn initialize(
62 server_url: impl Into<String>,
63 tenant_id: impl Into<String>,
64 credentials: AuthCredentials,
65 ) -> Result<Self> {
66 let client = Self::_new(server_url);
67 client._connect_impl().await?;
68 client
69 ._authenticate_impl(&tenant_id.into(), credentials)
70 .await?;
71 client._init_file_watcher_impl().await?;
72 client._start_file_monitoring_impl().await?;
73 Ok(client)
74 }
75
76 #[inline]
78 fn _new(server_url: impl Into<String>) -> Self {
79 let client_id = Uuid::new_v4().to_string();
80
81 Self {
82 client_id: client_id.clone(),
83 server_url: server_url.into(),
84 connection: Arc::new(RwLock::new(None)),
85 state: create_shared_state(client_id),
86 heartbeat_interval: Duration::from_secs(30),
87 max_reconnect_attempts: 5,
88 reconnect_attempts: Arc::new(AtomicU64::new(0)),
89 virtual_files: Arc::new(RwLock::new(std::collections::HashMap::new())),
90 file_watcher: Arc::new(RwLock::new(None)),
91 heartbeat_task: Arc::new(RwLock::new(None)),
92 }
93 }
94
95 pub fn new(server_url: impl Into<String>) -> Self {
97 Self::_new(server_url)
98 }
99
100 pub fn with_id(server_url: impl Into<String>, client_id: impl Into<String>) -> Self {
102 let client_id = client_id.into();
103
104 Self {
105 client_id: client_id.clone(),
106 server_url: server_url.into(),
107 connection: Arc::new(RwLock::new(None)),
108 state: create_shared_state(client_id),
109 heartbeat_interval: Duration::from_secs(30),
110 max_reconnect_attempts: 5,
111 reconnect_attempts: Arc::new(AtomicU64::new(0)),
112 virtual_files: Arc::new(RwLock::new(std::collections::HashMap::new())),
113 file_watcher: Arc::new(RwLock::new(None)),
114 heartbeat_task: Arc::new(RwLock::new(None)),
115 }
116 }
117
118 pub fn id(&self) -> &str {
120 &self.client_id
121 }
122
123 pub fn server_url(&self) -> &str {
125 &self.server_url
126 }
127
128 #[inline]
130 async fn _connect_impl(&self) -> Result<()> {
131 let mut state = self.state.write().await;
132 state.connection_state = ConnectionState::Connecting;
133 drop(state);
134
135 match Connection::new(&self.server_url).await {
136 Ok(conn) => {
137 let mut state = self.state.write().await;
138 state.connection_state = ConnectionState::Connected;
139 state.session_id = Some(Uuid::new_v4().to_string());
140 drop(state);
141
142 let mut conn_guard = self.connection.write().await;
143 *conn_guard = Some(conn);
144
145 self.reconnect_attempts.store(0, Ordering::SeqCst);
147
148 Ok(())
156 }
157 Err(e) => {
158 let mut state = self.state.write().await;
159 state.connection_state = ConnectionState::Disconnected;
160 Err(e)
161 }
162 }
163 }
164
165 pub async fn connect(&self) -> Result<()> {
167 self._connect_impl().await
168 }
169
170 #[inline]
172 async fn _authenticate_impl(
173 &self,
174 tenant_id: &str,
175 credentials: AuthCredentials,
176 ) -> Result<AuthContext> {
177 self.send_message(ClientMessage::Authenticate {
178 tenant_id: tenant_id.to_string(),
179 client_version: env!("CARGO_PKG_VERSION").to_string(),
180 credentials: credentials.clone(),
181 })
182 .await?;
183
184 if let Some(conn) = &*self.connection.read().await {
186 if let Ok(Ok(Some(ServerMessage::AuthenticationResult {
187 success,
188 permissions,
189 ..
190 }))) = tokio::time::timeout(Duration::from_secs(10), conn.recv()).await
191 {
192 if success {
193 let auth_context =
194 AuthContext::new(tenant_id.to_string(), permissions.unwrap_or_default());
195
196 let mut state = self.state.write().await;
197 state.connection_state = ConnectionState::Authenticated;
198 state.add_auth_context(tenant_id.to_string(), auth_context.clone());
199
200 Ok(auth_context)
201 } else {
202 Err(CommyError::AuthenticationFailed(
203 "Authentication denied by server".to_string(),
204 ))
205 }
206 } else {
207 Err(CommyError::Timeout)
208 }
209 } else {
210 Err(CommyError::ConnectionLost(
211 "Connection not established".to_string(),
212 ))
213 }
214 }
215
216 pub async fn authenticate(
218 &self,
219 tenant_id: impl Into<String>,
220 credentials: AuthCredentials,
221 ) -> Result<AuthContext> {
222 self._authenticate_impl(&tenant_id.into(), credentials)
223 .await
224 }
225
226 pub async fn create_service(&self, tenant_id: &str, service_name: &str) -> Result<String> {
233 let state = self.state.read().await;
235 if !state.is_authenticated_to(tenant_id) {
236 return Err(CommyError::PermissionDenied(format!(
237 "Not authenticated to tenant: {}",
238 tenant_id
239 )));
240 }
241 drop(state);
242
243 self.send_message(ClientMessage::CreateService {
245 tenant_id: tenant_id.to_string(),
246 service_name: service_name.to_string(),
247 })
248 .await?;
249
250 if let Some(conn) = &*self.connection.read().await {
252 if let Ok(Ok(Some(ServerMessage::Service { service_id, .. }))) =
253 tokio::time::timeout(Duration::from_secs(10), conn.recv()).await
254 {
255 Ok(service_id)
256 } else {
257 Err(CommyError::Timeout)
258 }
259 } else {
260 Err(CommyError::ConnectionLost(
261 "Connection lost during create_service".to_string(),
262 ))
263 }
264 }
265
266 pub async fn get_service(&self, tenant_id: &str, service_name: &str) -> Result<Service> {
273 let state = self.state.read().await;
275 if !state.is_authenticated_to(tenant_id) {
276 return Err(CommyError::PermissionDenied(format!(
277 "Not authenticated to tenant: {}",
278 tenant_id
279 )));
280 }
281 drop(state);
282
283 self.send_message(ClientMessage::GetService {
285 tenant_id: tenant_id.to_string(),
286 service_name: service_name.to_string(),
287 })
288 .await?;
289
290 if let Some(conn) = &*self.connection.read().await {
292 if let Ok(Ok(Some(ServerMessage::Service {
293 service_id,
294 service_name,
295 tenant_id: resp_tenant,
296 file_path,
297 }))) = tokio::time::timeout(Duration::from_secs(10), conn.recv()).await
298 {
299 let service = Service::new(service_id, service_name, resp_tenant, file_path);
300 Ok(service)
301 } else {
302 Err(CommyError::Timeout)
303 }
304 } else {
305 Err(CommyError::ConnectionLost(
306 "Connection lost during get_service".to_string(),
307 ))
308 }
309 }
310
311 pub async fn delete_service(&self, tenant_id: &str, service_name: &str) -> Result<()> {
318 let state = self.state.read().await;
320 if !state.is_authenticated_to(tenant_id) {
321 return Err(CommyError::PermissionDenied(format!(
322 "Not authenticated to tenant: {}",
323 tenant_id
324 )));
325 }
326 drop(state);
327
328 self.send_message(ClientMessage::DeleteService {
330 tenant_id: tenant_id.to_string(),
331 service_name: service_name.to_string(),
332 })
333 .await?;
334
335 if let Some(conn) = &*self.connection.read().await {
337 if let Ok(Ok(Some(ServerMessage::Result { success: true, .. }))) =
338 tokio::time::timeout(Duration::from_secs(10), conn.recv()).await
339 {
340 Ok(())
341 } else {
342 Err(CommyError::Timeout)
343 }
344 } else {
345 Err(CommyError::ConnectionLost(
346 "Connection lost during delete_service".to_string(),
347 ))
348 }
349 }
350
351 pub async fn create_tenant(&self, tenant_id: &str, tenant_name: &str) -> Result<String> {
361 self.send_message(ClientMessage::CreateTenant {
363 tenant_id: tenant_id.to_string(),
364 tenant_name: tenant_name.to_string(),
365 })
366 .await?;
367
368 if let Some(conn) = &*self.connection.read().await {
370 if let Ok(Ok(Some(ServerMessage::TenantResult {
371 success: true,
372 tenant_id: returned_id,
373 ..
374 }))) = tokio::time::timeout(Duration::from_secs(10), conn.recv()).await
375 {
376 Ok(returned_id)
377 } else {
378 Err(CommyError::Timeout)
379 }
380 } else {
381 Err(CommyError::ConnectionLost(
382 "Connection lost during create_tenant".to_string(),
383 ))
384 }
385 }
386
387 pub async fn delete_tenant(&self, tenant_id: &str) -> Result<()> {
398 self.send_message(ClientMessage::DeleteTenant {
400 tenant_id: tenant_id.to_string(),
401 })
402 .await?;
403
404 if let Some(conn) = &*self.connection.read().await {
406 if let Ok(Ok(Some(ServerMessage::Result { success: true, .. }))) =
407 tokio::time::timeout(Duration::from_secs(10), conn.recv()).await
408 {
409 Ok(())
410 } else {
411 Err(CommyError::Timeout)
412 }
413 } else {
414 Err(CommyError::ConnectionLost(
415 "Connection lost during delete_tenant".to_string(),
416 ))
417 }
418 }
419
420 pub async fn read_variable(&self, service_id: &str, variable_name: &str) -> Result<Vec<u8>> {
422 self.send_message(ClientMessage::ReadVariable {
423 service_id: service_id.to_string(),
424 variable_name: variable_name.to_string(),
425 })
426 .await?;
427
428 if let Some(conn) = &*self.connection.read().await {
430 if let Ok(Ok(Some(ServerMessage::VariableData { data, .. }))) =
431 tokio::time::timeout(Duration::from_secs(10), conn.recv()).await
432 {
433 Ok(data)
434 } else {
435 Err(CommyError::Timeout)
436 }
437 } else {
438 Err(CommyError::ConnectionLost(
439 "Connection lost during read_variable".to_string(),
440 ))
441 }
442 }
443
444 pub async fn write_variable(
446 &self,
447 service_id: &str,
448 variable_name: &str,
449 data: Vec<u8>,
450 ) -> Result<()> {
451 self.send_message(ClientMessage::WriteVariable {
452 service_id: service_id.to_string(),
453 variable_name: variable_name.to_string(),
454 data,
455 })
456 .await?;
457
458 Ok(())
459 }
460
461 pub async fn subscribe(&self, service_id: &str, variable_name: &str) -> Result<()> {
463 self.send_message(ClientMessage::Subscribe {
464 service_id: service_id.to_string(),
465 variable_name: variable_name.to_string(),
466 })
467 .await?;
468
469 Ok(())
470 }
471
472 pub async fn unsubscribe(&self, service_id: &str, variable_name: &str) -> Result<()> {
474 self.send_message(ClientMessage::Unsubscribe {
475 service_id: service_id.to_string(),
476 variable_name: variable_name.to_string(),
477 })
478 .await?;
479
480 Ok(())
481 }
482
483 pub async fn heartbeat(&self) -> Result<()> {
485 self.send_message(ClientMessage::Heartbeat {
486 client_id: self.client_id.clone(),
487 })
488 .await?;
489
490 if let Some(conn) = &*self.connection.read().await {
492 match tokio::time::timeout(Duration::from_secs(10), conn.recv()).await {
493 Ok(Ok(Some(ServerMessage::Heartbeat { .. }))) => {
494 }
496 _ => {
497 }
499 }
500 }
501
502 let mut state = self.state.write().await;
503 state.touch();
504
505 Ok(())
506 }
507
508 pub async fn disconnect(&self) -> Result<()> {
510 self.send_message(ClientMessage::Disconnect {
511 client_id: self.client_id.clone(),
512 })
513 .await?;
514
515 let mut conn_guard = self.connection.write().await;
516 *conn_guard = None;
517
518 let mut state = self.state.write().await;
519 state.reset();
520
521 Ok(())
522 }
523
524 pub async fn is_connected(&self) -> bool {
526 self.connection.read().await.is_some()
527 }
528
529 pub async fn connection_state(&self) -> ConnectionState {
531 let state = self.state.read().await;
532 state.connection_state
533 }
534
535 pub async fn authenticated_tenants(&self) -> Vec<String> {
537 let state = self.state.read().await;
538 state
539 .authenticated_tenants()
540 .into_iter()
541 .map(|s| s.to_string())
542 .collect()
543 }
544
545 pub async fn is_authenticated_to(&self, tenant_id: &str) -> bool {
547 let state = self.state.read().await;
548 state.is_authenticated_to(tenant_id)
549 }
550
551 pub async fn idle_seconds(&self) -> u64 {
553 let state = self.state.read().await;
554 state.idle_seconds()
555 }
556
557 async fn send_message(&self, msg: ClientMessage) -> Result<()> {
559 let result = self.send_message_once(msg.clone()).await;
561
562 if let Err(CommyError::ConnectionLost(_)) = result {
564 let current_attempts = self.reconnect_attempts.fetch_add(1, Ordering::SeqCst);
565
566 if current_attempts < self.max_reconnect_attempts as u64 {
567 let delay = Duration::from_secs(2_u64.pow(current_attempts as u32).min(16));
569 tokio::time::sleep(delay).await;
570
571 if let Ok(()) = self._connect_impl().await {
573 return self.send_message_once(msg).await;
575 }
576 }
577
578 return Err(CommyError::ConnectionLost(format!(
579 "Connection lost after {} reconnection attempts",
580 current_attempts + 1
581 )));
582 }
583
584 result
585 }
586
587 async fn send_message_once(&self, msg: ClientMessage) -> Result<()> {
589 let conn_guard = self.connection.read().await;
590 if let Some(conn) = conn_guard.as_ref() {
591 conn.send(msg).await?;
592
593 let mut state = self.state.write().await;
594 state.touch();
595
596 Ok(())
597 } else {
598 Err(CommyError::ConnectionLost(
599 "Connection not established".to_string(),
600 ))
601 }
602 }
603
604 async fn start_heartbeat_task(&self) {
606 let mut task_guard = self.heartbeat_task.write().await;
608 if let Some(handle) = task_guard.take() {
609 handle.abort();
610 }
611
612 let interval = self.heartbeat_interval;
613 let client_id = self.client_id.clone();
614 let connection = Arc::clone(&self.connection);
615 let state = Arc::clone(&self.state);
616
617 let handle = tokio::spawn(async move {
618 let mut interval_timer = tokio::time::interval(interval);
619 interval_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
620
621 loop {
622 interval_timer.tick().await;
623
624 let conn_guard = connection.read().await;
626 if let Some(conn) = conn_guard.as_ref() {
627 let heartbeat_msg = ClientMessage::Heartbeat {
629 client_id: client_id.clone(),
630 };
631
632 if conn.send(heartbeat_msg).await.is_ok() {
633 let mut state_guard = state.write().await;
635 state_guard.touch();
636 } else {
637 break;
639 }
640 } else {
641 break;
643 }
644 }
645 });
646
647 *task_guard = Some(handle);
648 }
649
650 #[inline]
652 async fn _init_file_watcher_impl(&self) -> Result<()> {
653 let watcher = VariableFileWatcher::new(None).await?;
654 let watcher = Arc::new(watcher);
655 watcher.start_watching().await?;
656 *self.file_watcher.write().await = Some(watcher);
657 Ok(())
658 }
659
660 pub async fn init_file_watcher(&self) -> Result<()> {
662 self._init_file_watcher_impl().await
663 }
664
665 pub async fn get_virtual_service_file(
671 &self,
672 tenant_id: &str,
673 service_name: &str,
674 ) -> Result<Arc<VirtualVariableFile>> {
675 let service_id = format!("{}_{}", tenant_id, service_name);
676
677 {
679 let vfiles = self.virtual_files.read().await;
680 if let Some(vf) = vfiles.get(&service_id) {
681 return Ok(Arc::clone(vf));
682 }
683 }
684
685 let vf = Arc::new(VirtualVariableFile::new(
687 service_id.clone(),
688 service_name.to_string(),
689 tenant_id.to_string(),
690 ));
691
692 if let Some(watcher_guard) = self.file_watcher.read().await.as_ref() {
694 watcher_guard
695 .register_virtual_file(service_id.clone(), Arc::clone(&vf))
696 .await?;
697 }
698
699 let mut vfiles = self.virtual_files.write().await;
701 vfiles.insert(service_id, Arc::clone(&vf));
702
703 Ok(vf)
704 }
705
706 #[inline]
711 async fn _start_file_monitoring_impl(&self) -> Result<()> {
712 let watcher = self.file_watcher.read().await;
713 if watcher.is_some() {
714 Ok(())
716 } else {
717 drop(watcher);
718 self._init_file_watcher_impl().await
719 }
720 }
721
722 pub async fn start_file_monitoring(&self) -> Result<()> {
727 self._start_file_monitoring_impl().await
728 }
729
730 pub async fn wait_for_file_change(&self) -> Result<Option<crate::watcher::FileChangeEvent>> {
732 let watcher = self.file_watcher.read().await;
733 if let Some(w) = watcher.as_ref() {
734 Ok(w.next_change().await)
735 } else {
736 Err(CommyError::InvalidState(
737 "File watcher not initialized. Call start_file_monitoring() first".to_string(),
738 ))
739 }
740 }
741
742 pub async fn try_get_file_change(&self) -> Result<Option<crate::watcher::FileChangeEvent>> {
744 let watcher = self.file_watcher.read().await;
745 if let Some(w) = watcher.as_ref() {
746 Ok(w.try_next_change().await)
747 } else {
748 Err(CommyError::InvalidState(
749 "File watcher not initialized. Call start_file_monitoring() first".to_string(),
750 ))
751 }
752 }
753
754 pub async fn stop_file_monitoring(&self) -> Result<()> {
756 if let Some(watcher) = self.file_watcher.write().await.take() {
757 watcher.stop_watching().await?;
758 }
759 Ok(())
760 }
761}
762
#[cfg(test)]
mod tests {
    use super::*;

    /// Server URL shared by every test; no connection is ever opened to it.
    const URL: &str = "wss://localhost:9000";

    #[test]
    fn test_client_creation() {
        let c = Client::new(URL);
        assert!(!c.id().is_empty());
        assert_eq!(c.server_url(), URL);
    }

    #[test]
    fn test_client_with_custom_id() {
        let c = Client::with_id(URL, "my_client");
        assert_eq!(c.id(), "my_client");
    }

    #[tokio::test]
    async fn test_is_connected_initially_false() {
        let connected = Client::new(URL).is_connected().await;
        assert!(!connected);
    }

    #[tokio::test]
    async fn test_idle_seconds() {
        let idle_for = Client::new(URL).idle_seconds().await;
        assert!(idle_for < 2);
    }
}