Skip to main content

scram_rs/
scram_cb.rs

1/*-
 * Scram-rs - a SCRAM authentication and authorization library
3 * 
4 * Copyright (C) 2021  Aleksandr Morozov
5 * Copyright (C) 2025 Aleksandr Morozov
6 * 
 * The Scram-rs crate can be redistributed and/or modified
 * under the terms of any one of the following licenses:
9 *
10 *   1. the Mozilla Public License Version 2.0 (the “MPL”) OR
11 *
12 *   2. The MIT License (MIT)
13 *                     
14 *   3. EUROPEAN UNION PUBLIC LICENCE v. 1.2 EUPL © the European Union 2007, 2016
15 */
16
17
18#[cfg(feature = "std")]
19use std::fmt;
20#[cfg(not(feature = "std"))]
21use core::fmt;
22
23#[cfg(not(feature = "std"))]
24use alloc::vec::Vec;
25
26#[cfg(not(feature = "std"))]
27use crate::alloc::string::ToString;
28
29use base64::Engine;
30use base64::engine::general_purpose;
31
32use crate::scram_cbh::{ScramCbHelper};
33use crate::{ScramRuntimeError, ScramServerError};
34
35use super::scram_common::ScramType;
36use super::scram_error::{ScramResult, ScramErrorCode};
37use super::scram_error;
38
/// Channel binding mode announced by the client in the GS2 header.
///
/// The first two variants carry no channel binding data; the remaining
/// three correspond to a `p=<cb-name>` header and require data from the
/// underlying TLS connection.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum ChannelBindType
{
    /// GS2 flag `n`: the client does not use channel binding.
    None,

    /// GS2 flag `y`: the client does not think the server supports
    /// channel binding.
    Unsupported,

    /// `p=tls-unique` channel binding.
    TlsUnique,

    /// `p=tls-server-end-point` channel binding.
    TlsServerEndpoint,

    /// `p=tls-exporter` channel binding.
    TlsExporter,
}
54
55impl fmt::Display for ChannelBindType
56{
57    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result 
58    {
59        match *self 
60        {
61            Self::None => write!(f, "None"),
62            Self::Unsupported => write!(f, "Unsupported"),
63            Self::TlsUnique => write!(f, "TlsUnique"),
64            Self::TlsServerEndpoint => write!(f, "TlsServerEndpoint"),
65            Self::TlsExporter => write!(f, "TlsExporter"),
66        }
67    }
68}
69
70/// Converts channel binding type directly from string, for server side use only
71impl TryFrom<&str> for ChannelBindType
72{
73    type Error = ScramRuntimeError;
74
75    fn try_from(value: &str) -> Result<Self, Self::Error> 
76    {
77        match value
78        {
79            "n" => 
80                return Ok(Self::None),
81            "y" => 
82                return Ok(Self::Unsupported),
83            "none" => 
84                return Ok(Self::None),
85            "unsupported" => 
86                return Ok(Self::Unsupported),
87            "tls-unique" => 
88                return Ok(Self::TlsUnique),
89            "tls-server-end-point" => 
90                return Ok(Self::TlsServerEndpoint),
91            _ => 
92                scram_error!(ScramErrorCode::ProtocolViolation, ScramServerError::ChannelBindingNotSupported, 
93                    "Unknown channel bind type: '{}'", value),
94        }
95    }
96}
97
98impl ChannelBindType
99{
100    /// Initializes enum as n,,
101    pub 
102    fn n() -> Self
103    {
104        return Self::None;
105    }
106
107    /// Initializes enum as y,,
108    pub 
109    fn y() -> Self
110    {
111        return Self::Unsupported;
112    }
113
114    /// Initializes enum as p=tls-server-end-point
115    pub 
116    fn tls_server_endpoint() -> Self
117    {
118        return Self::TlsServerEndpoint;
119    }
120
121    /// Initializes enum as p=tls-unique
122    pub 
123    fn tls_unique() -> Self
124    {
125        return Self::TlsUnique;
126    }
127
128    /// Initializes p=tls-exporter
129    pub 
130    fn tls_exporter() -> Self
131    {
132        return Self::TlsExporter;
133    }
134
135    /// Converts the enum [ChannelBindType] to protocol header text 
136    pub 
137    fn convert2header(&self) -> &str
138    {
139        match self
140        {
141            Self::None => 
142                return "n,,",
143            Self::Unsupported => 
144                return "y,,",
145            Self::TlsUnique => 
146                return "p=tls-unique,,",
147            Self::TlsServerEndpoint => 
148                return "p=tls-server-end-point,,",
149            Self::TlsExporter => 
150                return "p=tls-exporter,,",
151        }
152    }
153 
154    pub(crate)  
155    fn get_cb_data_raw(&self, sbh: &dyn ScramCbHelper) -> ScramResult<Vec<u8>>
156    {
157        match self
158        {
159            Self::None => 
160                return Ok(b"".to_vec()),
161            Self::Unsupported => 
162                return Ok(b"".to_vec()),
163            Self::TlsUnique => 
164                return sbh.get_tls_unique(),
165            Self::TlsServerEndpoint => 
166                return sbh.get_tls_server_endpoint(),
167            Self::TlsExporter => 
168                return sbh.get_tls_exporter(),
169        }
170    }
171
172    /// Verifies the client initial request of the Channel Bind type
173    /// If client picks SCRAM-? without -PLUS extension, then it should not
174    /// require any channel binding i.e n -(None) or y-(Unsupported)
175    /// 
176    /// # Arguments
177    /// 
178    /// * `st` - picked SCRAM type
179    /// 
180    /// # Returns
181    /// 
182    /// * [ScramResult] - returns nothing in payload or error
183    pub 
184    fn server_initial_verify_client_cb(&self, st: &ScramType) -> ScramResult<()>
185    {
186        if st.scram_chan_bind == true
187        {
188            // server with channel binding support
189            match *self
190            {
191                Self::TlsUnique => return Ok(()),
192
193                Self::TlsServerEndpoint => return Ok(()),
194
195                Self::TlsExporter => return Ok(()),
196
197                Self::None => 
198                    scram_error!(
199                        ScramErrorCode::MalformedScramMsg,
200                        ScramServerError::ChannelBindingsDontMatch,
201                        "malformed message, client selected *-PLUS but message did not include cb data!"
202                    ),
203
204                //if client pickes -PLUS and sends y(Unsupported) then this is malformed message
205                Self::Unsupported => 
206                    scram_error!(
207                        ScramErrorCode::MalformedScramMsg,
208                        ScramServerError::ChannelBindingsDontMatch,
209                        "malformed message, client picked -PLUS, but did not provide cb data"
210                    ),
211            }
212        }
213        else
214        {
215            // -PLUS was not picked
216            match *self
217            {
218                Self::TlsUnique | Self::TlsServerEndpoint | Self::TlsExporter => 
219                    scram_error!(
220                        ScramErrorCode::MalformedScramMsg,
221                        ScramServerError::ChannelBindingsDontMatch,
222                        "client provided channel binding data while picking SCRAM without -PLUS extension!"
223                    ),
224
225                Self::None => return Ok(()),
226
227                // client picks SCRAM-? and thinks we don't support channel binding
228                Self::Unsupported => return Ok(()),
229            }
230        }
231    }
232
233    /// Server uses this function to verify the the client channel bind
234    /// in final message.
235    /// 
236    /// # Arguments
237    /// 
238    /// * `st` - [ScramType] a current scram type
239    /// 
240    /// * `cb_attr` - a received channel binding data from client in base64 format
241    /// 
242    /// * `endpoint_hash` - a servers TLS end point certificate
243    /// 
244    /// # Returns
245    /// 
246    /// * [ScramResult] nothing in payload or error
247    pub(crate) 
248    fn server_final_verify_client_cb(
249        &self, 
250        st: &ScramType, 
251        cb_attr: &str,
252        sbh: &dyn ScramCbHelper,
253    ) -> ScramResult<()>
254    {
255        // verify input
256        // If we are not using channel binding, the binding data is expected
257        // to always be "biws", which is "n,," base64-encoded, or "eSws",
258        // which is "y,,".  We also have to check whether the flag is the same
259        // one that the client originally sent. auth-scram_8c_source.c:1310
260
261        let comp_cb_attr = 
262            match *self
263            {
264                Self::TlsExporter => 
265                {
266                    if st.scram_chan_bind == false
267                    {
268                        scram_error!(
269                            ScramErrorCode::InternalError,
270                            ScramServerError::OtherError,
271                            "assertion trap: cb_type: {}, scram_type: {}, does not \
272                                include SCRAM channel binding",
273                            self, st
274                        );
275                    }
276
277                    let expt_tls = 
278                        match sbh.get_tls_exporter()
279                        {
280                            Ok(r) => r,   
281                            Err(e) => 
282                                scram_error!(
283                                    ScramErrorCode::InternalError,
284                                    ScramServerError::OtherError,
285                                    "assertion trap: cb_type: {}, scram_type: {} \
286                                    TlsExporter requires endpoint data from TLS connection! \
287                                    Error returned: '{}'", 
288                                    self, st, e
289                                ),
290                        };
291
292                    let header = 
293                        [
294                            self.convert2header().as_bytes(), //"p=tls-server-end-point,,".as_bytes(),
295                            expt_tls.as_slice(),
296                            //cbind_data
297                        ].concat();
298                    
299                    let bheader = general_purpose::STANDARD.encode(header);
300                    
301                    bheader
302                },
303                Self::TlsUnique =>
304                {
305                    if st.scram_chan_bind == false
306                    {
307                        scram_error!(
308                            ScramErrorCode::InternalError,
309                            ScramServerError::OtherError,
310                            "assertion trap: cb_type: {}, scram_type: {}, does not \
311                                include SCRAM channel binding",
312                            self, st
313                        );
314                    }
315
316                    let uniq_tls = 
317                        match sbh.get_tls_unique()
318                        {
319                            Ok(r) => r,   
320                            Err(e) => 
321                                scram_error!(
322                                    ScramErrorCode::InternalError,
323                                    ScramServerError::OtherError,
324                                    "assertion trap: cb_type: {}, scram_type: {} \
325                                    TlsUnique requires endpoint data from TLS connection! \
326                                    Error returned: '{}'", 
327                                    self, st, e
328                                ),
329                        };
330
331                    let header = 
332                        [
333                            self.convert2header().as_bytes(), //"p=tls-server-end-point,,".as_bytes(),
334                            uniq_tls.as_slice(),
335                            //cbind_data
336                        ].concat();
337                    
338                    let bheader = general_purpose::STANDARD.encode(header);
339                    
340                    bheader
341                },
342                Self::TlsServerEndpoint =>
343                {
344                    if st.scram_chan_bind == false
345                    {
346                        scram_error!(
347                            ScramErrorCode::InternalError,
348                            ScramServerError::OtherError,
349                            "assertion trap: cb_type: {}, scram_type: {}, does not \
350                                include SCRAM channel binding",
351                            self, st
352                        );
353                    }
354
355                    let endp_cert_hash = 
356                        match sbh.get_tls_server_endpoint()
357                        {
358                            Ok(r) => r,   
359                            Err(e) => 
360                                scram_error!(
361                                    ScramErrorCode::InternalError,
362                                    ScramServerError::OtherError,
363                                    "assertion trap: cb_type: {}, scram_type: {} \
364                                    TlsServerEndpoint requires endpoint data from TLS connection! \
365                                    Error returned: '{}'", 
366                                    self, st, e
367                                ),
368                        };
369
370                    //get the data from ChannelBindingData which contains 
371                    //hash data of server's SSL certificate and combine it
372                    let header = 
373                        [
374                            self.convert2header().as_bytes(), //"p=tls-server-end-point,,".as_bytes(),
375                            endp_cert_hash.as_slice(),
376                            //cbind_data
377                        ].concat();
378                    
379                    let bheader = general_purpose::STANDARD.encode(header);
380
381
382                    bheader
383                },
384                Self::Unsupported => 
385                {
386                    "eSws".to_string()
387                },
388                Self::None =>
389                {
390                    "biws".to_string()
391                }
392            };
393
394        if comp_cb_attr.as_str() == cb_attr
395        {
396            return Ok(());
397        }
398        else
399        {
400            scram_error!(
401                ScramErrorCode::VerificationError, 
402                ScramServerError::OtherError,
403                "SCRAM channel binding '{}' check failed! Scram type: {}",
404                self, st
405            );
406        }
407    }
408   
409}