1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
use failure::Fail;

/// Errors returned by the registry read/write functions in this module.
#[derive(Debug, Fail)]
pub enum RegistryError {
    /// The registry value exists but its type is not `REG_BINARY` (checked in `read_raw`).
    #[fail(display = "Invalid registry value type")]
    InvalidValueType,

    /// An underlying I/O error, e.g. from a registry call that failed; it is exposed as the
    /// failure cause.
    #[fail(display = "{}", _0)]
    Io(#[fail(cause)] ::std::io::Error),

    /// The configuration bytes could not be serialized or deserialized.
    // NOTE(review): unlike `Io`, the inner error is not marked `#[fail(cause)]`, so it is not
    // reported through `Fail::cause()`; consider adding it if `SerializationError: Fail`.
    #[fail(display = "{}", _0)]
    Serialization(super::serialization::SerializationError),
}

impl From<::std::io::Error> for RegistryError {
    /// Wraps an I/O error so that `?` can lift it into a `RegistryError::Io`.
    fn from(source: ::std::io::Error) -> Self {
        RegistryError::Io(source)
    }
}

impl From<super::serialization::SerializationError> for RegistryError {
    /// Wraps a (de)serialization failure so that `?` can lift it into a
    /// `RegistryError::Serialization`.
    fn from(source: super::serialization::SerializationError) -> Self {
        RegistryError::Serialization(source)
    }
}

use super::serialization;
use super::types;
use winreg::enums::*;
use winreg::{RegKey, RegValue};

/// Path (relative to the hive root) of the `Internet Settings\Connections` key.
const KEY_PATH: &str =
    "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Internet Settings\\Connections";
/// The same key under `SOFTWARE\WOW6432Node` (the 32-bit registry view on 64-bit Windows).
const KEY_PATH_WOW6432: &str =
    "SOFTWARE\\WOW6432Node\\Microsoft\\Windows\\CurrentVersion\\Internet Settings\\Connections";
/// Value name used by `get_current_user_location`.
pub const DEFAULT_CONNECTION_NAME: &str = "DefaultConnectionSettings";
/// Value name used by `get_winhttp_location`.
pub const WINHTTP_CONNECTION_NAME: &str = "WinHttpSettings";

/// Registry hive that holds a connection-settings value.
///
/// `open_key` maps `System` to `HKEY_LOCAL_MACHINE` and `CurrentUser` to `HKEY_CURRENT_USER`.
/// The enum has no payload, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Target {
    /// Machine-wide settings under `HKEY_LOCAL_MACHINE`.
    System,
    /// Per-user settings under `HKEY_CURRENT_USER`.
    CurrentUser,
}

/// Identifies one registry value holding a connection configuration: the hive to use and the
/// name of the value inside the `Connections` key.
#[derive(Debug, Clone)]
pub struct Location {
    /// Hive the value lives in.
    pub target: Target,
    /// Registry value name, e.g. `DEFAULT_CONNECTION_NAME` or `WINHTTP_CONNECTION_NAME`.
    pub connection_name: String,
}

/// Returns the location of the current user's default connection settings:
/// hive `Target::CurrentUser`, value `DEFAULT_CONNECTION_NAME`.
pub fn get_current_user_location() -> Location {
    let connection_name = DEFAULT_CONNECTION_NAME.to_owned();
    Location {
        connection_name,
        target: Target::CurrentUser,
    }
}

/// Returns the location of the machine-wide WinHTTP settings:
/// hive `Target::System`, value `WINHTTP_CONNECTION_NAME`.
pub fn get_winhttp_location() -> Location {
    let target = Target::System;
    let connection_name = WINHTTP_CONNECTION_NAME.to_string();
    Location {
        target,
        connection_name,
    }
}

fn open_key(target: &Target, write: bool, wow6432: bool) -> Result<RegKey, RegistryError> {
    let root_key = match target {
        Target::System => RegKey::predef(HKEY_LOCAL_MACHINE),
        Target::CurrentUser => RegKey::predef(HKEY_CURRENT_USER),
    };
    let access = if write { KEY_ALL_ACCESS } else { KEY_READ };
    let key_path = if wow6432 { KEY_PATH_WOW6432 } else { KEY_PATH };
    let key = root_key.open_subkey_with_flags(key_path, access)?;
    return Ok(key);
}

/// Stores `bytes` as a `REG_BINARY` value named `location.connection_name` in the
/// `Connections` key of `location.target`. Uses the native path, or the `WOW6432Node` copy when
/// `wow6432` is true.
///
/// # Errors
/// Propagates failures from opening the key or setting the value.
fn write_raw(location: &Location, bytes: &[u8], wow6432: bool) -> Result<(), RegistryError> {
    let key = open_key(&location.target, true, wow6432)?;
    let value = RegValue {
        vtype: REG_BINARY,
        bytes: bytes.to_vec(),
    };
    key.set_raw_value(&location.connection_name, &value)?;
    Ok(())
}

/// Serializes `config` and writes it to `location`.
///
/// For `Target::System`, the value is written first to the `WOW6432Node` copy of the key and
/// then to the native key. For `Target::CurrentUser`, only the native key is written.
///
/// # Errors
/// Fails on a serialization error or on the first registry write that fails. If the
/// `WOW6432Node` write fails, the native key is not written.
pub fn write_full(location: &Location, config: &types::FullConfig) -> Result<(), RegistryError> {
    let mut bytes = Vec::new();
    serialization::serialize(config, &mut bytes)?;

    // NOTE(review): `SOFTWARE\WOW6432Node` presumably does not exist on 32-bit Windows, in which
    // case this open fails and aborts the whole write; confirm whether that is intended.
    if let Target::System = location.target {
        write_raw(location, &bytes, true)?;
    }

    write_raw(location, &bytes, false)
}

/// Reads the raw bytes of `location.connection_name` from the native (non-`WOW6432Node`)
/// `Connections` key, opened read-only.
///
/// # Errors
/// Returns `RegistryError::InvalidValueType` if the value is not `REG_BINARY`. Failures to
/// open the key or read the value are propagated.
fn read_raw(location: &Location) -> Result<Vec<u8>, RegistryError> {
    let raw = open_key(&location.target, false, false)?
        .get_raw_value(&location.connection_name)?;

    if let REG_BINARY = raw.vtype {
        Ok(raw.bytes)
    } else {
        Err(RegistryError::InvalidValueType)
    }
}

/// Reads and deserializes the full configuration stored at `location`.
///
/// # Errors
/// Propagates registry errors from `read_raw` and deserialization errors from
/// `serialization::deserialize`.
pub fn read_full(location: &Location) -> Result<types::FullConfig, RegistryError> {
    let bytes = read_raw(location)?;
    // `&bytes[..]` is a `&[u8]`. It is kept as a slice in case `deserialize` is generic over
    // `io::Read`, which `&[u8]` implements but `&Vec<u8>` does not.
    Ok(serialization::deserialize(&bytes[..])?)
}

/// Returns the counter stored at `location` plus one.
///
/// Returns `0` if the configuration cannot be read for any reason (missing value, wrong type,
/// unparsable data, ...). The increment wraps on overflow, which matches release-build
/// arithmetic, instead of panicking in debug builds when the stored counter is `u32::MAX`.
pub fn get_next_counter(location: &Location) -> u32 {
    read_full(location)
        .map(|full| full.counter.wrapping_add(1))
        .unwrap_or(0)
}

/// Reads only the proxy-configuration part of the value stored at `location`.
///
/// # Errors
/// Same as `read_full`.
pub fn read(location: &Location) -> Result<types::ProxyConfig, RegistryError> {
    read_full(location).map(|full| full.config)
}

/// Replaces the proxy configuration at `location` with `config`.
///
/// The existing value is read first so that its counter can be continued. The new value is
/// written with `super::IE7_VERSION` as its version and the counter incremented by one. The
/// increment wraps on overflow instead of panicking in debug builds.
///
/// # Errors
/// Propagates errors from `read_full` (the current value must be readable) and `write_full`.
// NOTE(review): unlike `get_next_counter`, a missing or unreadable current value is an error
// here rather than starting from 0; confirm that is intended.
pub fn write(location: &Location, config: types::ProxyConfig) -> Result<(), RegistryError> {
    let full_before = read_full(location)?;
    let full_after = types::FullConfig {
        version: super::IE7_VERSION,
        counter: full_before.counter.wrapping_add(1),
        config,
    };
    write_full(location, &full_after)
}

/// Reads the proxy configuration at `location`, transforms it with `updater`, and writes the
/// result back.
///
/// The new value uses `super::IE7_VERSION` as its version and the previous counter
/// incremented by one.
///
/// # Errors
/// Propagates errors from `read_full` (in which case `updater` is never called) and from
/// `write_full`.
pub fn update<F>(location: &Location, updater: F) -> Result<(), RegistryError>
where
    F: FnOnce(types::ProxyConfig) -> types::ProxyConfig,
{
    let full_before = read_full(location)?;
    // Copy the counter out before `full_before.config` is moved into `updater`.
    // `wrapping_add` keeps release-build behaviour and avoids a debug-build overflow panic
    // when the stored counter is already at its maximum.
    let counter = full_before.counter.wrapping_add(1);
    let config = updater(full_before.config);

    let full_after = types::FullConfig {
        version: super::IE7_VERSION,
        counter,
        config,
    };
    write_full(location, &full_after)
}