wtx 0.43.1

A collection of different transport implementations and related tools focused primarily on web technologies.
Documentation
use crate::{
  collection::{ArrayVectorU8, Vector},
  database::{
    Executor,
    client::mysql::{
      Config, ExecutorBuffer, MysqlError, MysqlExecutor,
      auth_plugin::AuthPlugin,
      capability::Capability,
      misc::{decode, fetch_msg, write_and_send_packet},
      mysql_executor::DFLT_PACKET_SIZE,
      protocol::{
        auth_switch_req::AuthSwitchReq, auth_switch_res::AuthSwitchRes,
        handshake_req::HandshakeReq, handshake_res::HandshakeRes, ok_res::OkRes,
      },
    },
  },
  misc::{LeaseMut, from_utf8_basic, net::PartitionedFilledBuffer},
  rng::CryptoRng,
  stream::Stream,
};

impl<E, EB, S> MysqlExecutor<E, EB, S>
where
  E: From<crate::Error>,
  EB: LeaseMut<ExecutorBuffer>,
  S: Stream,
{
  /// Receives the first server message and settles the capability flags of the connection.
  ///
  /// A fixed set of client capabilities is assembled (plus `ConnectWithDb` when `config.db` is
  /// set), then one message is fetched into `net_buffer` and decoded as a [`HandshakeRes`].
  ///
  /// Returns the client flags intersected with the capabilities carried by the response,
  /// alongside the decoded response itself, which borrows from `net_buffer`.
  pub(crate) async fn connect0<'nb, IS>(
    config: &Config<'_>,
    net_buffer: &'nb mut PartitionedFilledBuffer,
    sequence_id: &mut u8,
    stream: &mut IS,
  ) -> Result<(u64, HandshakeRes<'nb>), E>
  where
    IS: Stream,
  {
    // Flags that are always requested, regardless of the configuration.
    let always_requested = [
      Capability::DeprecateEof,
      Capability::FoundRows,
      Capability::IgnoreSpace,
      Capability::MultiResults,
      Capability::MultiStatements,
      Capability::OptionalResultsetMetadata,
      Capability::PluginAuth,
      Capability::PluginAuthLenencData,
      Capability::Protocol41,
      Capability::PsMultiResults,
      Capability::SecureConnection,
      Capability::Transactions,
    ]
    .into_iter()
    .fold(0u64, |flags, capability| flags | u64::from(capability));
    let db_flag = if config.db.is_some() { u64::from(Capability::ConnectWithDb) } else { 0 };
    let requested = always_requested | db_flag;

    let _ = fetch_msg(requested, net_buffer, sequence_id, stream).await?;
    let mut payload = net_buffer.current();
    let handshake: HandshakeRes<'_> = decode(&mut payload, ())?;
    // Only keep what both sides declared.
    let negotiated = requested & handshake.capabilities;
    Ok((negotiated, handshake))
  }

  /// Builds the [`HandshakeReq`] that answers `handshake_res` and sends it through `stream`.
  ///
  /// The authentication response is only produced when the handshake response names a plugin
  /// **and** `config.password` is set; in that case the password bytes are masked by the plugin
  /// using the authentication data of the handshake response. Otherwise the request carries no
  /// authentication response.
  ///
  /// # Errors
  ///
  /// Propagates failures from the password masking and from writing/sending the packet.
  pub(crate) async fn connect1(
    (capabilities, sequence_id): (&mut u64, &mut u8),
    config: &Config<'_>,
    encode_buffer: &mut Vector<u8>,
    handshake_res: &HandshakeRes<'_>,
    stream: &mut S,
  ) -> Result<(), E> {
    let auth_response = match (handshake_res.auth_plugin, config.password) {
      (Some(plugin), Some(pw)) => Some(plugin.mask_pw(
        (&handshake_res.auth_plugin_data.0, handshake_res.auth_plugin_data.1),
        pw.as_bytes(),
      )?),
      _ => None,
    };
    let request = HandshakeReq {
      auth_plugin: handshake_res.auth_plugin,
      auth_response,
      collation: config.collation,
      database: config.db,
      max_packet_size: DFLT_PACKET_SIZE,
      username: config.user,
    };
    write_and_send_packet((capabilities, sequence_id), encode_buffer, request, stream).await?;
    Ok(())
  }

  /// Drives the message exchange that follows the handshake request until it completes.
  ///
  /// Every iteration fetches one message and dispatches on its first byte:
  ///
  /// * `0`: the message is decoded as an [`OkRes`] and the method returns successfully.
  /// * `254`: the message is decoded as an [`AuthSwitchRes`]. The plugin carried by the response
  ///   becomes the current plugin, the configured password (or its default value when absent)
  ///   is masked with it and the outcome is sent back as an [`AuthSwitchReq`].
  /// * anything else: the first two bytes are handed to `AuthPlugin::manage_caching_sha2`; a
  ///   `true` result ends the exchange, otherwise another message is fetched.
  ///
  /// # Errors
  ///
  /// `MysqlError::InvalidConnectionBytes` is returned for an empty message and, in the last
  /// case above, when no plugin is known, no password is configured or the message has fewer
  /// than two bytes. Failures of the called helpers are propagated.
  pub(crate) async fn connect2<RNG, const IS_TLS: bool>(
    auth_plugin_data: ([u8; 8], ArrayVectorU8<u8, 24>),
    (capabilities, sequence_id): (&mut u64, &mut u8),
    config: &Config<'_>,
    encode_buffer: &mut Vector<u8>,
    net_buffer: &mut PartitionedFilledBuffer,
    mut plugin: Option<AuthPlugin>,
    rng: &mut RNG,
    stream: &mut S,
  ) -> Result<(), E>
  where
    RNG: CryptoRng,
  {
    loop {
      let _ = fetch_msg(*capabilities, net_buffer, sequence_id, stream).await?;
      let mut packet = net_buffer.current();
      let Some((&tag, tail)) = packet.split_first() else {
        return Err(E::from(MysqlError::InvalidConnectionBytes.into()));
      };
      match tag {
        0 => {
          let _: OkRes = decode(&mut packet, ())?;
          return Ok(());
        }
        254 => {
          let allow_cleartext = config.enable_cleartext_plugin;
          let switch_rslt: Result<AuthSwitchRes, _> = decode(&mut packet, allow_cleartext);
          let switch = switch_rslt?;
          // Later iterations must use the plugin requested by this message.
          plugin = Some(switch.auth_plugin);
          let masked = switch.auth_plugin.mask_pw(
            match &switch.data {
              Some((lhs, rhs)) => (lhs.as_slice(), rhs.as_slice()),
              None => (&[][..], &[][..]),
            },
            config.password.unwrap_or_default().as_bytes(),
          )?;
          let reply = AuthSwitchReq(&masked);
          write_and_send_packet((capabilities, sequence_id), encode_buffer, reply, stream)
            .await?;
        }
        _ => {
          let (Some(active_plugin), Some(password)) = (plugin, &config.password) else {
            return Err(E::from(MysqlError::InvalidConnectionBytes.into()));
          };
          let Some(&second) = tail.first() else {
            return Err(E::from(MysqlError::InvalidConnectionBytes.into()));
          };
          let is_auth_ok = active_plugin
            .manage_caching_sha2::<_, _, _, IS_TLS>(
              (auth_plugin_data.0, &auth_plugin_data.1),
              [tag, second],
              (capabilities, sequence_id),
              encode_buffer,
              net_buffer,
              password,
              rng,
              stream,
            )
            .await?;
          if is_auth_ok {
            return Ok(());
          }
        }
      }
    }
  }

  /// Applies the session options of `config` with a single `SET` statement.
  ///
  /// The statement is assembled as a comma-separated list of, in this order:
  ///
  /// * an `sql_mode` assignment appending `NO_ENGINE_SUBSTITUTION` and/or `PIPES_AS_CONCAT` to
  ///   the current mode, when the corresponding flags are enabled;
  /// * a `time_zone='<value>'` assignment, when `config.timezone` is set;
  /// * a `NAMES <charset> COLLATE <collation>` clause, when `config.set_names` is enabled.
  ///
  /// Nothing is executed if none of the options is enabled.
  ///
  /// # Errors
  ///
  /// Propagates buffer growth failures, UTF-8 validation failures of the assembled statement
  /// and failures reported by `execute_ignored`.
  pub(crate) async fn connect3(&mut self, config: &Config<'_>) -> Result<(), E> {
    let mut buffer = Vector::new();
    buffer.extend_from_copyable_slice("SET ".as_bytes())?;
    let sql_mode_opt = match (config.no_engine_substitution, config.pipes_as_concat) {
      (false, false) => None,
      (false, true) => Some(&b"sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT')),"[..]),
      (true, false) => {
        Some(&b"sql_mode=(SELECT CONCAT(@@sql_mode, ',NO_ENGINE_SUBSTITUTION')),"[..])
      }
      (true, true) => Some(
        &b"sql_mode=(SELECT CONCAT(@@sql_mode, ',NO_ENGINE_SUBSTITUTION,PIPES_AS_CONCAT')),"[..],
      ),
    };
    if let Some(sql_mode) = sql_mode_opt {
      buffer.extend_from_copyable_slice(sql_mode)?;
    }
    if let Some(timezone) = &config.timezone {
      // The value is a string literal, so it needs both the opening and the closing quote.
      let _ = buffer.extend_from_copyable_slices([b"time_zone='", timezone.as_bytes(), b"',"])?;
    }
    if config.set_names {
      let _ = buffer.extend_from_copyable_slices([
        "NAMES ".as_bytes(),
        config.charset.strings().ident.as_bytes(),
        b" COLLATE ",
        config.collation.strings().ident.as_bytes(),
        b",",
      ])?;
    }
    // More than the 4 bytes of "SET " means at least one option was appended.
    if buffer.len() > 4 {
      // Every fragment ends with a comma, so the last one must be dropped.
      let _ = buffer.pop();
      self.execute_ignored(from_utf8_basic(&buffer).map_err(crate::Error::from)?).await?;
    }
    buffer.clear();
    Ok(())
  }
}