// palpo_core/error/kind_serde.rs

1use std::{
2    borrow::Cow,
3    collections::btree_map::{BTreeMap, Entry},
4    fmt,
5    time::Duration,
6};
7
8use crate::serde::{DeserializeFromCowStr, FromString};
9use serde::{
10    de::{self, Deserialize, Deserializer, MapAccess, Visitor},
11    ser::{self, Serialize, SerializeMap, Serializer},
12};
13use serde_json::from_value as from_json_value;
14
15use super::ErrorKind;
16use crate::PrivOwnedStr;
17
/// One key of the JSON object an [`ErrorKind`] is deserialized from.
///
/// Keys that some error code needs get a dedicated variant; any other key is
/// kept verbatim (borrowed when possible) in [`Field::Other`].
enum Field<'de> {
    /// `errcode`: selects the error variant.
    ErrCode,
    /// `soft_logout`: used by `ErrCode::UnknownToken`.
    SoftLogout,
    /// `retry_after_ms`: used by `ErrCode::LimitExceeded`.
    RetryAfterMs,
    /// `room_version`: used by `ErrCode::IncompatibleRoomVersion`.
    RoomVersion,
    /// `admin_contact`: used by `ErrCode::ResourceLimitExceeded`.
    AdminContact,
    /// `status`: associated with `ErrCode::BadStatus`.
    Status,
    /// `body`: associated with `ErrCode::BadStatus`.
    Body,
    /// `current_version`: used by `ErrCode::WrongRoomKeysVersion`.
    CurrentVersion,
    /// Any key not listed above.
    Other(Cow<'de, str>),
}

30impl<'de> Field<'de> {
31    fn new(s: Cow<'de, str>) -> Field<'de> {
32        match s.as_ref() {
33            "errcode" => Self::ErrCode,
34            "soft_logout" => Self::SoftLogout,
35            "retry_after_ms" => Self::RetryAfterMs,
36            "room_version" => Self::RoomVersion,
37            "admin_contact" => Self::AdminContact,
38            "status" => Self::Status,
39            "body" => Self::Body,
40            "current_version" => Self::CurrentVersion,
41            _ => Self::Other(s),
42        }
43    }
44}
45
46impl<'de> Deserialize<'de> for Field<'de> {
47    fn deserialize<D>(deserializer: D) -> Result<Field<'de>, D::Error>
48    where
49        D: Deserializer<'de>,
50    {
51        struct FieldVisitor;
52
53        impl<'de> Visitor<'de> for FieldVisitor {
54            type Value = Field<'de>;
55
56            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
57                formatter.write_str("any struct field")
58            }
59
60            fn visit_str<E>(self, value: &str) -> Result<Field<'de>, E>
61            where
62                E: de::Error,
63            {
64                Ok(Field::new(Cow::Owned(value.to_owned())))
65            }
66
67            fn visit_borrowed_str<E>(self, value: &'de str) -> Result<Field<'de>, E>
68            where
69                E: de::Error,
70            {
71                Ok(Field::new(Cow::Borrowed(value)))
72            }
73
74            fn visit_string<E>(self, value: String) -> Result<Field<'de>, E>
75            where
76                E: de::Error,
77            {
78                Ok(Field::new(Cow::Owned(value)))
79            }
80        }
81
82        deserializer.deserialize_identifier(FieldVisitor)
83    }
84}
85
86struct ErrorKindVisitor;
87
88impl<'de> Visitor<'de> for ErrorKindVisitor {
89    type Value = ErrorKind;
90
91    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
92        formatter.write_str("enum ErrorKind")
93    }
94
95    fn visit_map<V>(self, mut map: V) -> Result<ErrorKind, V::Error>
96    where
97        V: MapAccess<'de>,
98    {
99        let mut errcode = None;
100        let mut soft_logout = None;
101        let mut retry_after_ms = None;
102        let mut room_version = None;
103        let mut admin_contact = None;
104        let mut status: Option<()> = None;
105        let mut body: Option<()> = None;
106        let mut current_version = None;
107        let mut extra = BTreeMap::new();
108
109        macro_rules! set_field {
110            (errcode) => {
111                set_field!(@inner errcode)
112            };
113            ($field:ident) => {
114                match errcode {
115                    Some(set_field!(@variant_containing $field)) | None => {
116                        set_field!(@inner $field)
117                    }
118                    // if we already know we're deserializing a different variant to the one
119                    // containing this field, ignore its value.
120                    Some(_) => {
121                        let _ = map.next_value::<de::IgnoredAny>()?;
122                    },
123                }
124            };
125            (@variant_containing soft_logout) => { ErrCode::UnknownToken };
126            (@variant_containing retry_after_ms) => { ErrCode::LimitExceeded };
127            (@variant_containing room_version) => { ErrCode::IncompatibleRoomVersion };
128            (@variant_containing admin_contact) => { ErrCode::ResourceLimitExceeded };
129            (@variant_containing status) => { ErrCode::BadStatus };
130            (@variant_containing body) => { ErrCode::BadStatus };
131            (@variant_containing current_version) => { ErrCode::WrongRoomKeysVersion };
132            (@inner $field:ident) => {
133                {
134                    if $field.is_some() {
135                        return Err(de::Error::duplicate_field(stringify!($field)));
136                    }
137                    $field = Some(map.next_value()?);
138                }
139            };
140        }
141
142        while let Some(key) = map.next_key()? {
143            match key {
144                Field::ErrCode => set_field!(errcode),
145                Field::SoftLogout => set_field!(soft_logout),
146                Field::RetryAfterMs => set_field!(retry_after_ms),
147                Field::RoomVersion => set_field!(room_version),
148                Field::AdminContact => set_field!(admin_contact),
149                Field::Status => set_field!(status),
150                Field::Body => set_field!(body),
151                Field::CurrentVersion => set_field!(current_version),
152                Field::Other(other) => match extra.entry(other.into_owned()) {
153                    Entry::Vacant(v) => {
154                        v.insert(map.next_value()?);
155                    }
156                    Entry::Occupied(o) => {
157                        return Err(de::Error::custom(format!("duplicate field `{}`", o.key())));
158                    }
159                },
160            }
161        }
162
163        let errcode = errcode.ok_or_else(|| de::Error::missing_field("errcode"))?;
164
165        Ok(match errcode {
166            ErrCode::Forbidden => ErrorKind::Forbidden,
167            ErrCode::UnknownToken => ErrorKind::UnknownToken {
168                soft_logout: soft_logout
169                    .map(from_json_value)
170                    .transpose()
171                    .map_err(de::Error::custom)?
172                    .unwrap_or_default(),
173            },
174            ErrCode::MissingToken => ErrorKind::MissingToken,
175            ErrCode::BadJson => ErrorKind::BadJson,
176            ErrCode::NotJson => ErrorKind::NotJson,
177            ErrCode::NotFound => ErrorKind::NotFound,
178            ErrCode::LimitExceeded => ErrorKind::LimitExceeded {
179                retry_after_ms: retry_after_ms
180                    .map(from_json_value::<u64>)
181                    .transpose()
182                    .map_err(de::Error::custom)?
183                    .map(Into::into)
184                    .map(Duration::from_millis),
185            },
186            ErrCode::Unknown => ErrorKind::Unknown,
187            ErrCode::Unrecognized => ErrorKind::Unrecognized,
188            ErrCode::Unauthorized => ErrorKind::Unauthorized,
189            ErrCode::UserDeactivated => ErrorKind::UserDeactivated,
190            ErrCode::UserInUse => ErrorKind::UserInUse,
191            ErrCode::InvalidUsername => ErrorKind::InvalidUsername,
192            ErrCode::RoomInUse => ErrorKind::RoomInUse,
193            ErrCode::InvalidRoomState => ErrorKind::InvalidRoomState,
194            ErrCode::ThreepidInUse => ErrorKind::ThreepidInUse,
195            ErrCode::ThreepidNotFound => ErrorKind::ThreepidNotFound,
196            ErrCode::ThreepidAuthFailed => ErrorKind::ThreepidAuthFailed,
197            ErrCode::ThreepidDenied => ErrorKind::ThreepidDenied,
198            ErrCode::ServerNotTrusted => ErrorKind::ServerNotTrusted,
199            ErrCode::UnsupportedRoomVersion => ErrorKind::UnsupportedRoomVersion,
200            ErrCode::IncompatibleRoomVersion => ErrorKind::IncompatibleRoomVersion {
201                room_version: from_json_value(room_version.ok_or_else(|| de::Error::missing_field("room_version"))?)
202                    .map_err(de::Error::custom)?,
203            },
204            ErrCode::BadState => ErrorKind::BadState,
205            ErrCode::GuestAccessForbidden => ErrorKind::GuestAccessForbidden,
206            ErrCode::CaptchaNeeded => ErrorKind::CaptchaNeeded,
207            ErrCode::CaptchaInvalid => ErrorKind::CaptchaInvalid,
208            ErrCode::MissingParam => ErrorKind::MissingParam,
209            ErrCode::InvalidParam => ErrorKind::InvalidParam,
210            ErrCode::TooLarge => ErrorKind::TooLarge,
211            ErrCode::Exclusive => ErrorKind::Exclusive,
212            ErrCode::ResourceLimitExceeded => ErrorKind::ResourceLimitExceeded {
213                admin_contact: from_json_value(admin_contact.ok_or_else(|| de::Error::missing_field("admin_contact"))?)
214                    .map_err(de::Error::custom)?,
215            },
216            ErrCode::CannotLeaveServerNoticeRoom => ErrorKind::CannotLeaveServerNoticeRoom,
217            ErrCode::WeakPassword => ErrorKind::WeakPassword,
218            ErrCode::UnableToAuthorizeJoin => ErrorKind::UnableToAuthorizeJoin,
219            ErrCode::UnableToGrantJoin => ErrorKind::UnableToGrantJoin,
220            ErrCode::BadAlias => ErrorKind::BadAlias,
221            ErrCode::DuplicateAnnotation => ErrorKind::DuplicateAnnotation,
222            ErrCode::NotYetUploaded => ErrorKind::NotYetUploaded,
223            ErrCode::CannotOverwriteMedia => ErrorKind::CannotOverwriteMedia,
224            #[cfg(feature = "unstable-msc3575")]
225            ErrCode::UnknownPos => ErrorKind::UnknownPos,
226            ErrCode::UrlNotSet => ErrorKind::UrlNotSet,
227            ErrCode::BadStatus => ErrorKind::BadStatus,
228            ErrCode::ConnectionFailed => ErrorKind::ConnectionFailed,
229            ErrCode::ConnectionTimeout => ErrorKind::ConnectionTimeout,
230            ErrCode::WrongRoomKeysVersion => ErrorKind::WrongRoomKeysVersion {
231                current_version: from_json_value(
232                    current_version.ok_or_else(|| de::Error::missing_field("current_version"))?,
233                )
234                .map_err(de::Error::custom)?,
235            },
236            ErrCode::_Custom(errcode) => ErrorKind::_Custom { errcode, extra },
237        })
238    }
239}
240
// Error codes that `ErrorKindVisitor` dispatches on.
//
// The string form of each variant comes from the `FromString` / `DeserializeFromCowStr`
// derives. With `rename_all = "M_MATRIX_ERROR_CASE"`, a variant such as
// `IncompatibleRoomVersion` presumably maps to `M_INCOMPATIBLE_ROOM_VERSION` (the
// spelling the tests in this file use); confirm against the `palpo_enum` macro.
// `_Custom` presumably receives any code not listed here; the visitor turns it into
// `ErrorKind::_Custom` together with the unrecognized fields.
#[derive(FromString, DeserializeFromCowStr)]
#[palpo_enum(rename_all = "M_MATRIX_ERROR_CASE")]
enum ErrCode {
    Forbidden,
    UnknownToken,
    MissingToken,
    BadJson,
    NotJson,
    NotFound,
    LimitExceeded,
    Unknown,
    Unrecognized,
    Unauthorized,
    UserDeactivated,
    UserInUse,
    InvalidUsername,
    RoomInUse,
    InvalidRoomState,
    ThreepidInUse,
    ThreepidNotFound,
    ThreepidAuthFailed,
    ThreepidDenied,
    ServerNotTrusted,
    UnsupportedRoomVersion,
    IncompatibleRoomVersion,
    BadState,
    GuestAccessForbidden,
    CaptchaNeeded,
    CaptchaInvalid,
    MissingParam,
    InvalidParam,
    TooLarge,
    Exclusive,
    ResourceLimitExceeded,
    CannotLeaveServerNoticeRoom,
    WeakPassword,
    UnableToAuthorizeJoin,
    UnableToGrantJoin,
    BadAlias,
    DuplicateAnnotation,
    // `alias` presumably makes the unstable MSC2246 identifier parse to this variant too.
    #[palpo_enum(alias = "FI.MAU.MSC2246_NOT_YET_UPLOADED")]
    NotYetUploaded,
    // `alias` presumably makes the unstable MSC2246 identifier parse to this variant too.
    #[palpo_enum(alias = "FI.MAU.MSC2246_CANNOT_OVERWRITE_MEDIA")]
    CannotOverwriteMedia,
    // Only compiled with the `unstable-msc3575` feature.
    #[cfg(feature = "unstable-msc3575")]
    UnknownPos,
    UrlNotSet,
    BadStatus,
    ConnectionFailed,
    ConnectionTimeout,
    WrongRoomKeysVersion,
    // Catch-all holding the raw error code string.
    _Custom(PrivOwnedStr),
}

295impl<'de> Deserialize<'de> for ErrorKind {
296    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
297    where
298        D: Deserializer<'de>,
299    {
300        deserializer.deserialize_map(ErrorKindVisitor)
301    }
302}
303
304impl Serialize for ErrorKind {
305    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
306    where
307        S: Serializer,
308    {
309        let mut st = serializer.serialize_map(None)?;
310        st.serialize_entry("errcode", self.as_ref())?;
311        match self {
312            Self::UnknownToken { soft_logout: true } => {
313                st.serialize_entry("soft_logout", &true)?;
314            }
315            Self::LimitExceeded {
316                retry_after_ms: Some(duration),
317            } => {
318                st.serialize_entry(
319                    "retry_after_ms",
320                    &u64::try_from(duration.as_millis()).map_err(ser::Error::custom)?,
321                )?;
322            }
323            Self::IncompatibleRoomVersion { room_version } => {
324                st.serialize_entry("room_version", room_version)?;
325            }
326            Self::ResourceLimitExceeded { admin_contact } => {
327                st.serialize_entry("admin_contact", admin_contact)?;
328            }
329            Self::_Custom { extra, .. } => {
330                for (k, v) in extra {
331                    st.serialize_entry(k, v)?;
332                }
333            }
334            _ => {}
335        }
336        st.end()
337    }
338}
339
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use serde_json::{from_value as from_json_value, json, to_value as to_json_value};

    use super::ErrorKind;
    use crate::room_version_id;

    #[test]
    fn deserialize_forbidden() {
        let deserialized: ErrorKind = from_json_value(json!({ "errcode": "M_FORBIDDEN" })).unwrap();
        assert_eq!(deserialized, ErrorKind::Forbidden);
    }

    #[test]
    fn deserialize_forbidden_with_extra_fields() {
        let deserialized: ErrorKind = from_json_value(json!({
            "errcode": "M_FORBIDDEN",
            "error": "…",
        }))
        .unwrap();

        assert_eq!(deserialized, ErrorKind::Forbidden);
    }

    #[test]
    fn deserialize_skips_field_of_other_variant() {
        // `soft_logout` only belongs to `M_UNKNOWN_TOKEN`, so its value is not
        // type-checked for a different error code.
        let deserialized: ErrorKind = from_json_value(json!({
            "errcode": "M_FORBIDDEN",
            "soft_logout": "not a bool",
        }))
        .unwrap();

        assert_eq!(deserialized, ErrorKind::Forbidden);
    }

    #[test]
    fn deserialize_missing_errcode_fails() {
        let result = from_json_value::<ErrorKind>(json!({ "error": "…" }));
        assert!(result.is_err());
    }

    #[test]
    fn deserialize_unknown_token_defaults_soft_logout_to_false() {
        let deserialized: ErrorKind = from_json_value(json!({ "errcode": "M_UNKNOWN_TOKEN" })).unwrap();
        assert_eq!(deserialized, ErrorKind::UnknownToken { soft_logout: false });
    }

    #[test]
    fn deserialize_unknown_token_with_soft_logout() {
        let deserialized: ErrorKind = from_json_value(json!({
            "errcode": "M_UNKNOWN_TOKEN",
            "soft_logout": true,
        }))
        .unwrap();

        assert_eq!(deserialized, ErrorKind::UnknownToken { soft_logout: true });
    }

    #[test]
    fn deserialize_limit_exceeded_with_retry_after() {
        let deserialized: ErrorKind = from_json_value(json!({
            "errcode": "M_LIMIT_EXCEEDED",
            "retry_after_ms": 2000,
        }))
        .unwrap();

        assert_eq!(
            deserialized,
            ErrorKind::LimitExceeded {
                retry_after_ms: Some(Duration::from_millis(2000))
            }
        );
    }

    #[test]
    fn deserialize_limit_exceeded_without_retry_after() {
        let deserialized: ErrorKind = from_json_value(json!({ "errcode": "M_LIMIT_EXCEEDED" })).unwrap();
        assert_eq!(deserialized, ErrorKind::LimitExceeded { retry_after_ms: None });
    }

    #[test]
    fn deserialize_incompatible_room_version() {
        let deserialized: ErrorKind = from_json_value(json!({
            "errcode": "M_INCOMPATIBLE_ROOM_VERSION",
            "room_version": "7",
        }))
        .unwrap();

        assert_eq!(
            deserialized,
            ErrorKind::IncompatibleRoomVersion {
                room_version: room_version_id!("7")
            }
        );
    }

    #[test]
    fn deserialize_incompatible_room_version_requires_room_version() {
        let result = from_json_value::<ErrorKind>(json!({ "errcode": "M_INCOMPATIBLE_ROOM_VERSION" }));
        assert!(result.is_err());
    }

    #[test]
    fn limit_exceeded_round_trips() {
        let kind = ErrorKind::LimitExceeded {
            retry_after_ms: Some(Duration::from_millis(1500)),
        };
        let value = to_json_value(&kind).unwrap();

        assert_eq!(value["retry_after_ms"], json!(1500));
        assert_eq!(from_json_value::<ErrorKind>(value).unwrap(), kind);
    }

    #[test]
    fn unknown_token_round_trips() {
        for soft_logout in [false, true] {
            let kind = ErrorKind::UnknownToken { soft_logout };
            let value = to_json_value(&kind).unwrap();
            assert_eq!(from_json_value::<ErrorKind>(value).unwrap(), kind);
        }
    }

    #[test]
    fn incompatible_room_version_round_trips() {
        let kind = ErrorKind::IncompatibleRoomVersion {
            room_version: room_version_id!("7"),
        };
        let value = to_json_value(&kind).unwrap();
        assert_eq!(from_json_value::<ErrorKind>(value).unwrap(), kind);
    }
}