import { useEffect, useRef, useCallback } from 'react'
import { p256 } from '@noble/curves/nist.js'
import { gcm } from '@noble/ciphers/aes.js'
import { randomBytes } from '@noble/ciphers/utils.js'
import { hkdf } from '@noble/hashes/hkdf.js'
import { sha256 } from '@noble/hashes/sha2.js'
/**
 * Encodes bytes as unpadded base64url (RFC 4648 §5 alphabet, no '=' padding).
 *
 * @param {Uint8Array|number[]} bytes - byte values in 0..255.
 * @returns {string} base64url text.
 */
function b64urlEncode(bytes) {
  // btoa expects a "binary string": one character per byte value.
  const binary = Array.from(bytes, (byte) => String.fromCharCode(byte)).join('')
  const standard = btoa(binary)
  return standard
    .replace(/\+/g, '-')
    .replace(/\//g, '_')
    .replace(/=+$/, '')
}
/**
 * Decodes base64url text (padded or unpadded) into raw bytes.
 *
 * @param {string} str - base64url-encoded input.
 * @returns {Uint8Array} decoded bytes.
 * @throws {DOMException} from atob when the input is not valid base64.
 */
function b64urlDecode(str) {
  const standard = str.replace(/-/g, '+').replace(/_/g, '/')
  // Restore '=' padding up to the next multiple of 4 characters.
  const padded = standard.padEnd(Math.ceil(standard.length / 4) * 4, '=')
  // atob yields one character (0..255) per decoded byte.
  return Uint8Array.from(atob(padded), (ch) => ch.charCodeAt(0))
}
/**
 * Generates a fresh P-256 key pair.
 *
 * @returns {{ privateKey: Uint8Array, publicKey: Uint8Array }} the secret key and
 *   the public key in uncompressed encoding.
 */
function generateKeyPair() {
  const privateKey = p256.utils.randomSecretKey()
  // `false` selects the uncompressed point encoding.
  const publicKey = p256.getPublicKey(privateKey, false)
  return { privateKey, publicKey }
}
/**
 * Derives the 32-byte AES session key from an ECDH exchange on P-256.
 *
 * @param {Uint8Array} clientPrivateKey - our secret key.
 * @param {Uint8Array} serverPublicKeyBytes - the server's encoded public key.
 * @returns {Uint8Array} 32-byte key from HKDF-SHA256 (no salt, fixed info string)
 *   over the shared point's X coordinate.
 */
function deriveAesKey(clientPrivateKey, serverPublicKeyBytes) {
  const context = new TextEncoder().encode('j-remote-aes256gcm')
  // Uncompressed shared point: 1 prefix byte, then X (32 bytes), then Y.
  const point = p256.getSharedSecret(clientPrivateKey, serverPublicKeyBytes, false)
  const xCoord = point.slice(1, 33)
  return hkdf(sha256, xCoord, undefined, context, 32)
}
/**
 * Encrypts a string with AES-GCM under a fresh random 12-byte nonce.
 *
 * @param {Uint8Array} key - AES key.
 * @param {string} plaintext - text to encrypt (UTF-8 encoded first).
 * @returns {Uint8Array} the nonce (12 bytes) followed by the GCM output.
 */
function aesEncrypt(key, plaintext) {
  const NONCE_BYTES = 12
  const iv = randomBytes(NONCE_BYTES)
  const sealed = gcm(key, iv).encrypt(new TextEncoder().encode(plaintext))
  // Frame as nonce || sealed so the receiver can recover the nonce.
  const framed = new Uint8Array(NONCE_BYTES + sealed.length)
  framed.set(iv)
  framed.set(sealed, NONCE_BYTES)
  return framed
}
/**
 * Decrypts a nonce-prefixed AES-GCM frame back into a string.
 *
 * @param {Uint8Array} key - AES key.
 * @param {Uint8Array} data - 12-byte nonce followed by the GCM output.
 * @returns {string} the UTF-8 decoded plaintext.
 * @throws when authentication fails (errors from `gcm(...).decrypt` propagate).
 */
function aesDecrypt(key, data) {
  const NONCE_BYTES = 12
  const opened = gcm(key, data.slice(0, NONCE_BYTES)).decrypt(data.slice(NONCE_BYTES))
  return new TextDecoder().decode(opened)
}
/**
 * React hook that keeps an encrypted WebSocket connection open and returns a
 * `send` function for it.
 *
 * Handshake, as handled below:
 *   1. the server sends plaintext JSON `{ type: 'server_hello', server_pk }`
 *      (base64url-encoded P-256 public key);
 *   2. the client replies with plaintext `{ type: 'key_exchange', client_pk }`
 *      and derives the AES-GCM session key (ECDH + HKDF);
 *   3. the server confirms with an encrypted `{ type: 'key_exchange_ok' }`.
 * After that, every frame in both directions is binary: 12-byte nonce followed
 * by the GCM output. Messages passed to `send` while the socket is open but the
 * handshake is unfinished are queued and flushed once it completes.
 *
 * NOTE(review): the server public key is not authenticated here, so an active
 * man-in-the-middle is not prevented by this code alone.
 *
 * The socket reconnects 1.5 s after any close until the effect is torn down.
 * `onMessage` and `onStatusChange` are effect dependencies: callers should pass
 * stable (memoized) functions, or the connection is rebuilt on every change.
 *
 * @param {string} url - WebSocket endpoint.
 * @param {(msg: any) => void} onMessage - receives each decrypted, JSON-parsed
 *   message after the handshake.
 * @param {(connected: boolean) => void} onStatusChange - `true` when the
 *   handshake completes, `false` when the connection closes or is torn down.
 * @returns {(data: any) => void} stable function that JSON-serializes, encrypts
 *   and sends `data`.
 */
export function useWebSocket(url, onMessage, onStatusChange) {
  const wsRef = useRef(null)
  const aesKeyRef = useRef(null)
  const readyRef = useRef(false)
  const pendingRef = useRef([])

  const send = useCallback((data) => {
    const ws = wsRef.current
    const aesKey = aesKeyRef.current
    // No open socket: the message is dropped (nothing is queued across reconnects).
    if (!ws || ws.readyState !== WebSocket.OPEN) return
    const json = JSON.stringify(data)
    // Open but the key exchange is unfinished: queue for the handshake flush.
    if (!readyRef.current || !aesKey) {
      pendingRef.current.push(json)
      return
    }
    try {
      ws.send(aesEncrypt(aesKey, json).buffer)
    } catch (err) {
      console.error('加密发送失败', err)
    }
  }, [])

  useEffect(() => {
    let reconnectTimer = null
    let pingInterval = null
    // Set by the cleanup. Sockets created by this effect check it so that, once
    // superseded, they never touch the refs now owned by the next connection.
    let destroyed = false

    // Encrypts `json` with the session key and sends it if `ws` is still open.
    function sendEncrypted(ws, json) {
      const buf = aesEncrypt(aesKeyRef.current, json)
      if (ws.readyState === WebSocket.OPEN) ws.send(buf.buffer)
    }

    // Answers server_hello: derives the session key and sends our public key.
    function handleServerHello(ws, serverPkB64) {
      try {
        const serverPkBytes = b64urlDecode(serverPkB64)
        const { privateKey, publicKey } = generateKeyPair()
        // Derive first so an unusable server key aborts before we announce a key.
        const aesKey = deriveAesKey(privateKey, serverPkBytes)
        ws.send(JSON.stringify({ type: 'key_exchange', client_pk: b64urlEncode(publicKey) }))
        aesKeyRef.current = aesKey
      } catch (err) {
        console.error('密钥协商失败', err)
        ws.close()
      }
    }

    // Runs on key_exchange_ok: marks the session ready, flushes queued messages,
    // requests a sync, and starts the 10 s encrypted ping.
    function completeHandshake(ws) {
      readyRef.current = true
      onStatusChange(true)
      for (const json of pendingRef.current.splice(0)) sendEncrypted(ws, json)
      sendEncrypted(ws, JSON.stringify({ type: 'sync' }))
      clearInterval(pingInterval)
      pingInterval = setInterval(() => {
        if (ws.readyState === WebSocket.OPEN && aesKeyRef.current && readyRef.current) {
          try {
            ws.send(aesEncrypt(aesKeyRef.current, JSON.stringify({ type: 'ping' })).buffer)
          } catch {
            // Best effort: a failed ping is ignored; a dead socket surfaces via onclose.
          }
        }
      }, 10000)
    }

    // Routes one incoming frame: handshake frames until ready, app messages after.
    function handleMessage(ws, data) {
      if (!readyRef.current) {
        if (typeof data === 'string') {
          const msg = JSON.parse(data)
          if (msg.type === 'server_hello' && msg.server_pk) {
            handleServerHello(ws, msg.server_pk)
            return
          }
        }
        if (data instanceof ArrayBuffer && aesKeyRef.current) {
          const msg = JSON.parse(aesDecrypt(aesKeyRef.current, new Uint8Array(data)))
          if (msg.type === 'key_exchange_ok') completeHandshake(ws)
        }
        // Any other frame before the handshake completes is ignored.
        return
      }
      if (data instanceof ArrayBuffer && aesKeyRef.current) {
        onMessage(JSON.parse(aesDecrypt(aesKeyRef.current, new Uint8Array(data))))
      }
    }

    // Opens a fresh socket with a clean session state.
    function connect() {
      if (destroyed) return
      aesKeyRef.current = null
      readyRef.current = false
      pendingRef.current = []
      const ws = new WebSocket(url)
      ws.binaryType = 'arraybuffer'
      wsRef.current = ws
      ws.onclose = () => {
        clearInterval(pingInterval)
        // A close caused by the cleanup: the refs may already belong to a new
        // connection, and the cleanup has reported the status itself.
        if (destroyed) return
        onStatusChange(false)
        readyRef.current = false
        aesKeyRef.current = null
        reconnectTimer = setTimeout(connect, 1500)
      }
      // Nothing to do here; reconnection is driven from onclose.
      ws.onerror = () => {}
      ws.onmessage = (e) => {
        if (destroyed) return
        try {
          handleMessage(ws, e.data)
        } catch (err) {
          console.error('消息处理错误', err)
        }
      }
    }

    connect()
    return () => {
      destroyed = true
      clearInterval(pingInterval)
      clearTimeout(reconnectTimer)
      const ws = wsRef.current
      wsRef.current = null
      readyRef.current = false
      aesKeyRef.current = null
      ws?.close()
      // The socket's own onclose ignores this close, so report it here.
      onStatusChange(false)
    }
  }, [url, onMessage, onStatusChange])

  return send
}