1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
//! # Group service
//!
//! This module contains the group service that handles incoming requests
//! for group management.

use crate::state::{
    group::{Group, GroupId},
    parameters::Parameters,
};
use serde::{Deserialize, Serialize};
use strum::{Display, EnumString};

#[cfg(feature = "server")]
use super::{notification::Notification, Service, ServiceResponse};
#[cfg(feature = "server")]
use crate::state::{ClientId, State};
#[cfg(feature = "server")]
use json_rpc2::{Error, Request};
#[cfg(feature = "server")]
use std::str::FromStr;
#[cfg(feature = "server")]
use tokio::sync::Mutex;

/// Prefix for group routes.
///
/// Every [`GroupMethod`] wire name starts with this prefix
/// (`group_create`, `group_join`).
pub const ROUTE_PREFIX: &str = "group";

/// JSON-RPC methods served by the group service.
///
/// Each variant is parsed from (`EnumString`) and displayed as (`Display`)
/// the string given in its `#[strum(serialize = ...)]` attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display, EnumString)]
pub enum GroupMethod {
    /// `group_create`: create a new group and join the calling client to it.
    #[strum(serialize = "group_create")]
    GroupCreate,
    /// `group_join`: join the calling client to an existing group.
    #[strum(serialize = "group_join")]
    GroupJoin,
}

/// Params payload of a `group_create` request.
#[derive(Serialize, Deserialize)]
pub struct GroupCreateRequest {
    /// Parameters for the group to be created; the server runs
    /// `Parameters::validate` on them before creating the group.
    pub parameters: Parameters,
}

/// Result payload returned for a successful `group_create` request.
#[derive(Serialize, Deserialize)]
pub struct GroupCreateResponse {
    /// The group that was created.
    pub group: Group,
}

/// Params payload of a `group_join` request.
///
/// Field names are camel-cased on the wire, i.e. `{"groupId": ...}`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GroupJoinRequest {
    /// Identifier of the group to join (wire name `groupId`).
    pub group_id: GroupId,
}

/// Result payload returned for a successful `group_join` request.
#[derive(Serialize, Deserialize)]
pub struct GroupJoinResponse {
    /// The group the client joined.
    pub group: Group,
}

/// Group service that handles incoming requests and maps
/// them to the corresponding [`GroupMethod`] handler.
///
/// The service has no fields of its own; its handlers operate on the shared
/// `State` passed in with each request.
#[cfg(feature = "server")]
#[derive(Debug, Default, Clone, Copy)]
pub struct GroupService;

#[axum::async_trait]
#[cfg(feature = "server")]
impl Service for GroupService {
    /// Routes a JSON-RPC request to the `GroupService` handler named by
    /// `req.method()`.
    ///
    /// An unrecognised method name yields `Error::MethodNotFound` carrying the
    /// requested name and the request id; otherwise the selected handler's
    /// response is returned as-is.
    async fn handle(
        &self,
        req: &Request,
        ctx: (
            std::sync::Arc<State>,
            std::sync::Arc<Mutex<Vec<Notification>>>,
        ),
        client_id: ClientId,
    ) -> ServiceResponse {
        let requested = req.method();
        let method = GroupMethod::from_str(requested).map_err(|_| Error::MethodNotFound {
            name: requested.to_string(),
            id: req.id().clone(),
        })?;

        match method {
            GroupMethod::GroupCreate => self.group_create(req, ctx, client_id).await,
            GroupMethod::GroupJoin => self.group_join(req, ctx, client_id).await,
        }
    }
}

#[cfg(feature = "server")]
impl GroupService {
    /// Handles `group_create`.
    ///
    /// Validates the requested `Parameters`, registers a new group in the
    /// shared `State`, and joins the calling client to it. The reply carries
    /// the group as returned by `State::join_group`, i.e. the state *after*
    /// the creator joined, which matches what `group_join` returns.
    ///
    /// # Errors
    ///
    /// - The request params do not deserialize into `GroupCreateRequest`.
    /// - `Error::InvalidParams` when `Parameters::validate` rejects them.
    /// - A boxed error (via `Error::from`) if joining the new group or
    ///   serializing the response fails.
    async fn group_create(
        &self,
        req: &Request,
        ctx: (
            std::sync::Arc<State>,
            std::sync::Arc<Mutex<Vec<Notification>>>,
        ),
        client_id: ClientId,
    ) -> ServiceResponse {
        tracing::info!("Creating a new group");
        let params: GroupCreateRequest = req.deserialize()?;
        let (state, _) = ctx;
        params
            .parameters
            .validate()
            .map_err(|e| Error::InvalidParams {
                id: req.id().clone(),
                data: e.to_string(),
            })?;

        let created = state.add_group(params.parameters).await;
        // Reply with the group returned by the join (as `group_join` does) so the
        // response reflects whatever the join changed, rather than the pre-join
        // snapshot from `add_group`.
        // NOTE(review): if this join fails, the group has already been added to
        // `state` without its creator — confirm whether it should be removed.
        let group = state
            .join_group(created.id, client_id)
            .await
            .map_err(|e| Error::from(Box::from(e)))?;
        tracing::info!(group_id = group.id().to_string(), "Group created");
        Self::reply(req, &GroupCreateResponse { group })
    }

    /// Handles `group_join`.
    ///
    /// Adds the calling client to the group identified by `groupId` and
    /// returns the group as reported by `State::join_group`.
    ///
    /// # Errors
    ///
    /// - The request params do not deserialize into `GroupJoinRequest`.
    /// - `Error::InvalidParams` carrying the underlying message when
    ///   `State::join_group` fails (presumably e.g. an unknown group id).
    /// - A boxed error (via `Error::from`) if serializing the response fails.
    async fn group_join(
        &self,
        req: &Request,
        ctx: (
            std::sync::Arc<State>,
            std::sync::Arc<Mutex<Vec<Notification>>>,
        ),
        client_id: ClientId,
    ) -> ServiceResponse {
        let params: GroupJoinRequest = req.deserialize()?;
        tracing::info!(
            group_id = params.group_id.to_string(),
            "Joining client to group"
        );
        let (state, _) = ctx;
        let group = state
            .join_group(params.group_id, client_id)
            .await
            .map_err(|e| Error::InvalidParams {
                id: req.id().clone(),
                data: e.to_string(),
            })?;
        Self::reply(req, &GroupJoinResponse { group })
    }

    /// Serializes `payload` to JSON and wraps it as the successful response
    /// to `req`.
    ///
    /// # Errors
    ///
    /// Returns a boxed error (via `Error::from`) if `payload` cannot be
    /// serialized to a JSON value.
    fn reply<T: Serialize>(req: &Request, payload: &T) -> ServiceResponse {
        let value = serde_json::to_value(payload).map_err(|e| Error::from(Box::from(e)))?;
        Ok(Some((req, value).into()))
    }
}