barehttp 0.0.1

A minimal, explicit HTTP client for Rust with no_std support and blocking I/O
Documentation
use crate::error::SocketError;
use crate::socket::{SocketAddr, SocketFlags};

use core::sync::atomic::{AtomicBool, Ordering};
use windows_sys::Win32::Foundation::TRUE;
use windows_sys::Win32::Networking::WinSock::{
  AF_INET, INVALID_SOCKET, IPPROTO_TCP, SD_BOTH, SO_KEEPALIVE, SO_RCVTIMEO, SO_REUSEADDR, SO_SNDTIMEO, SOCK_STREAM,
  SOCKADDR_IN, SOCKET, SOCKET_ERROR, SOL_SOCKET, TCP_NODELAY, WSADATA, WSAGetLastError, WSAStartup, closesocket,
  connect, recv, send, setsockopt, shutdown, socket,
};
use windows_sys::core::BOOL;

static WSA_INITIALIZED: AtomicBool = AtomicBool::new(false);

fn ensure_wsa_initialized() -> Result<(), SocketError> {
  if WSA_INITIALIZED.load(Ordering::Acquire) {
    return Ok(());
  }

  unsafe {
    let mut wsa_data: WSADATA = core::mem::zeroed();
    let result = WSAStartup(0x0202, &raw mut wsa_data);
    if result != 0 {
      return Err(SocketError::OsError(result));
    }
    WSA_INITIALIZED.store(true, Ordering::Release);
  }

  Ok(())
}

const fn map_wsa_error(code: i32) -> SocketError {
  match code {
    10061 => SocketError::ConnectionRefused,
    10060 => SocketError::TimedOut,
    10035 => SocketError::WouldBlock,
    10004 => SocketError::Interrupted,
    10057 => SocketError::NotConnected,
    10022 => SocketError::InvalidAddress,
    _ => SocketError::OsError(code),
  }
}

fn get_last_wsa_error() -> SocketError {
  unsafe {
    let code = WSAGetLastError();
    map_wsa_error(code)
  }
}

pub struct OsSocket {
  socket: SOCKET,
  connected: bool,
}

impl OsSocket {
  pub fn new() -> Result<Self, SocketError> {
    ensure_wsa_initialized()?;

    unsafe {
      let sock = socket(i32::from(AF_INET), SOCK_STREAM, IPPROTO_TCP);
      if sock == INVALID_SOCKET {
        return Err(get_last_wsa_error());
      }

      Ok(Self {
        socket: sock,
        connected: false,
      })
    }
  }

  pub fn connect(
    &mut self,
    addr: &SocketAddr,
  ) -> Result<(), SocketError> {
    if self.connected {
      return Ok(());
    }

    match addr {
      SocketAddr::Ip { addr: ip_addr, port } => match ip_addr {
        crate::util::IpAddr::V4(ipv4) => self.connect_ipv4(*ipv4, *port)?,
        crate::util::IpAddr::V6(ipv6) => Self::connect_ipv6(*ipv6, *port)?,
      },
      SocketAddr::Hostname { host, port: host_port } => {
        let host_str = core::str::from_utf8(host).map_err(|_| SocketError::InvalidAddress)?;

        let addresses = crate::dns::os::resolve_host(host_str).map_err(|e| match e {
          crate::error::DnsError::ResolutionFailed(code) => SocketError::DnsResolutionFailed(code),
          crate::error::DnsError::NoAddressesFound => SocketError::DnsResolutionFailed(0),
          crate::error::DnsError::InvalidHostname => SocketError::InvalidAddress,
          crate::error::DnsError::Unsupported => SocketError::Unsupported,
          crate::error::DnsError::OsError(code) => SocketError::OsError(code),
        })?;

        let mut last_error = SocketError::ConnectionRefused;
        for ip_addr in &addresses {
          let result = match ip_addr {
            crate::util::IpAddr::V4(ipv4) => self.connect_ipv4(*ipv4, *host_port),
            crate::util::IpAddr::V6(ipv6) => Self::connect_ipv6(*ipv6, *host_port),
          };

          if result.is_ok() {
            return Ok(());
          }
          if let Err(e) = result {
            last_error = e;
          }
        }

        return Err(last_error);
      },
    }

    Ok(())
  }

  fn connect_ipv4(
    &mut self,
    addr: [u8; 4],
    port: u16,
  ) -> Result<(), SocketError> {
    let ip = u32::from_ne_bytes(addr);

    unsafe {
      let mut sockaddr: SOCKADDR_IN = core::mem::zeroed();
      sockaddr.sin_family = AF_INET;
      sockaddr.sin_port = port.to_be();
      sockaddr.sin_addr.S_un.S_addr = ip;

      #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
      let result = connect(
        self.socket,
        &raw const sockaddr as *const _,
        core::mem::size_of::<SOCKADDR_IN>() as i32,
      );

      if result == SOCKET_ERROR {
        return Err(get_last_wsa_error());
      }
    }

    self.connected = true;
    Ok(())
  }

  const fn connect_ipv6(
    _addr: [u16; 8],
    _port: u16,
  ) -> Result<(), SocketError> {
    Err(SocketError::Unsupported)
  }

  pub fn read(
    &mut self,
    buf: &mut [u8],
  ) -> Result<usize, SocketError> {
    if !self.connected {
      return Err(SocketError::NotConnected);
    }

    unsafe {
      #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
      let result = recv(self.socket, buf.as_mut_ptr() as *mut _, buf.len() as i32, 0);

      if result == SOCKET_ERROR {
        return Err(get_last_wsa_error());
      }

      if result == 0 {
        self.connected = false;
      }

      #[allow(clippy::cast_sign_loss)]
      {
        Ok(result as usize)
      }
    }
  }

  pub fn write(
    &mut self,
    buf: &[u8],
  ) -> Result<usize, SocketError> {
    if !self.connected {
      return Err(SocketError::NotConnected);
    }

    unsafe {
      #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
      let result = send(self.socket, buf.as_ptr() as *const _, buf.len() as i32, 0);

      if result == SOCKET_ERROR {
        return Err(get_last_wsa_error());
      }

      #[allow(clippy::cast_sign_loss)]
      {
        Ok(result as usize)
      }
    }
  }

  pub fn shutdown(&mut self) -> Result<(), SocketError> {
    if !self.connected {
      return Ok(());
    }

    unsafe {
      let result = shutdown(self.socket, SD_BOTH);
      if result == SOCKET_ERROR {
        let err = get_last_wsa_error();
        if !matches!(err, SocketError::NotConnected) {
          return Err(err);
        }
      }
    }

    self.connected = false;
    Ok(())
  }

  pub fn set_flags(
    &mut self,
    flags: SocketFlags,
  ) -> Result<(), SocketError> {
    unsafe {
      if flags.contains(SocketFlags::TCP_NODELAY) {
        let val: BOOL = TRUE;
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        let result = setsockopt(
          self.socket,
          IPPROTO_TCP,
          TCP_NODELAY,
          &raw const val as *const _,
          core::mem::size_of::<BOOL>() as i32,
        );
        if result == SOCKET_ERROR {
          return Err(get_last_wsa_error());
        }
      }

      if flags.contains(SocketFlags::KEEPALIVE) {
        let val: BOOL = TRUE;
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        let result = setsockopt(
          self.socket,
          SOL_SOCKET,
          SO_KEEPALIVE,
          &raw const val as *const _,
          core::mem::size_of::<BOOL>() as i32,
        );
        if result == SOCKET_ERROR {
          return Err(get_last_wsa_error());
        }
      }

      if flags.contains(SocketFlags::REUSEADDR) {
        let val: BOOL = TRUE;
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        let result = setsockopt(
          self.socket,
          SOL_SOCKET,
          SO_REUSEADDR,
          &raw const val as *const _,
          core::mem::size_of::<BOOL>() as i32,
        );
        if result == SOCKET_ERROR {
          return Err(get_last_wsa_error());
        }
      }
    }

    Ok(())
  }

  pub fn set_read_timeout(
    &mut self,
    timeout_ms: u32,
  ) -> Result<(), SocketError> {
    unsafe {
      #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
      let result = setsockopt(
        self.socket,
        SOL_SOCKET,
        SO_RCVTIMEO,
        &raw const timeout_ms as *const _,
        core::mem::size_of::<u32>() as i32,
      );
      if result == SOCKET_ERROR {
        return Err(get_last_wsa_error());
      }
    }
    Ok(())
  }

  pub fn set_write_timeout(
    &mut self,
    timeout_ms: u32,
  ) -> Result<(), SocketError> {
    unsafe {
      #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
      let result = setsockopt(
        self.socket,
        SOL_SOCKET,
        SO_SNDTIMEO,
        &raw const timeout_ms as *const _,
        core::mem::size_of::<u32>() as i32,
      );
      if result == SOCKET_ERROR {
        return Err(get_last_wsa_error());
      }
    }
    Ok(())
  }
}

impl Drop for OsSocket {
  fn drop(&mut self) {
    if self.socket != INVALID_SOCKET {
      unsafe {
        closesocket(self.socket);
      }
    }
  }
}