wrpc 0.16.0

WebAssembly component-native RPC framework based on WIT
Documentation
package wrpc

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"math"
)

// ByteStreamWriter encodes the contents of an io.Reader as a pending byte
// stream: a sequence of length-prefixed chunks followed by a terminating zero
// byte. The encoding itself is performed by WriteTo.
type ByteStreamWriter struct {
	r     io.Reader // source the stream contents are read from
	chunk []byte    // scratch buffer passed to r.Read; WriteTo allocates 8096 bytes when it is empty
}

// WriteTo reads v.r until io.EOF and writes its contents to w as a pending
// byte stream: every non-empty chunk is written as a 32-bit length (encoded by
// WriteUint32) followed by that many bytes, and the stream is terminated by
// exactly one zero byte.
//
// Reads that yield no data are never emitted as chunks: the decoding side
// (ByteStreamReader.Read) treats a zero chunk length as end of stream, so a
// zero-length chunk would terminate the stream prematurely or duplicate the
// terminator. Data returned together with a non-EOF read error is dropped and
// the error is returned.
//
// Output goes through a bufio.Writer that is flushed before returning. A flush
// failure is returned unless another error is already being returned, in which
// case it is only logged. If v.chunk is empty, an 8096-byte scratch buffer is
// allocated and kept on v.
func (v *ByteStreamWriter) WriteTo(w ByteWriter) (err error) {
	if len(v.chunk) == 0 {
		v.chunk = make([]byte, 8096)
	}
	buf := bufio.NewWriter(w)
	defer func() {
		if fErr := buf.Flush(); fErr != nil {
			if err == nil {
				err = fmt.Errorf("failed to flush pending byte stream writer: %w", fErr)
			} else {
				slog.Warn("failed to flush pending byte stream writer", "err", fErr)
			}
		}
	}()
	for {
		slog.Debug("reading pending byte stream contents")
		n, rErr := v.r.Read(v.chunk)
		end := rErr == io.EOF
		if end {
			slog.Debug("pending byte stream reached EOF")
		} else if rErr != nil {
			return fmt.Errorf("failed to read pending byte stream chunk: %w", rErr)
		}
		// io.Reader permits (0, nil) and (0, io.EOF); neither carries data, so
		// no chunk is written for them.
		if n > 0 {
			// Compare in uint64 so the check also compiles where int is 32 bits.
			if uint64(n) > math.MaxUint32 {
				return fmt.Errorf("pending byte stream chunk length of %d overflows a 32-bit integer", n)
			}
			slog.Debug("writing pending byte stream chunk length", "len", n)
			if _, err := WriteUint32(uint32(n), buf); err != nil {
				return fmt.Errorf("failed to write pending byte stream chunk length of %d: %w", n, err)
			}
			if _, err := buf.Write(v.chunk[:n]); err != nil {
				return fmt.Errorf("failed to write pending byte stream chunk contents: %w", err)
			}
		}
		if end {
			// Single end-of-stream marker.
			if err := buf.WriteByte(0); err != nil {
				return fmt.Errorf("failed to write pending byte stream end byte: %w", err)
			}
			return nil
		}
	}
}

// ByteStreamReader decodes a pending byte stream (length-prefixed chunks
// terminated by a zero length) from an underlying ByteReadCloser. buf tracks
// how many bytes of the current chunk have not been returned yet; when it is
// zero, the next Read starts by reading a new chunk length.
type ByteStreamReader struct {
	r   ByteReadCloser // underlying reader carrying the encoded stream
	buf uint32         // bytes remaining in the current chunk
}

// Read decodes up to len(p) bytes of stream payload into p.
//
// When no chunk is in progress, a 32-bit chunk length is first read with
// ReadUint32; a zero length marks the end of the stream and is reported as a
// bare io.EOF. At most the remainder of the current chunk is read in a single
// call, so a Read never crosses a chunk boundary.
//
// The number of bytes left in the current chunk is recorded as soon as the
// length is known and after every payload read, including reads that also
// return an error, so the reader's position within the chunk stays consistent
// with the bytes actually consumed. Errors from the underlying reader are
// wrapped; an io.EOF in the middle of a chunk is therefore not reported as a
// bare io.EOF.
func (r *ByteStreamReader) Read(p []byte) (int, error) {
	if r.buf == 0 {
		slog.Debug("reading pending byte stream chunk length")
		n, err := ReadUint32(r.r)
		if err != nil {
			return 0, fmt.Errorf("failed to read pending byte stream chunk length: %w", err)
		}
		if n == 0 {
			return 0, io.EOF
		}
		// Remember the length right away: it has already been consumed from
		// the wire and must not be re-read if the payload read below fails.
		r.buf = n
	}
	n := r.buf
	// Compare in uint64 so a length above MaxInt32 cannot go negative where
	// int is 32 bits.
	if uint64(len(p)) > uint64(n) {
		p = p[:n]
	}
	slog.Debug("reading pending byte stream chunk contents", "len", n)
	rn, err := r.r.Read(p)
	if rn > 0 && uint64(rn) <= uint64(n) {
		// Account for delivered bytes even when err is non-nil.
		r.buf = n - uint32(rn)
	}
	if err != nil {
		return rn, fmt.Errorf("failed to read pending stream chunk bytes: %w", err)
	}
	if rn > 0 && uint64(rn) > uint64(n) {
		return rn, errors.New("read more bytes than requested")
	}
	return rn, nil
}

// Close closes the underlying reader and returns the error it reports, if any.
func (r *ByteStreamReader) Close() error {
	return r.r.Close()
}

// NewByteStreamReader returns a ByteStreamReader that decodes a pending byte
// stream from r. The reader starts with no chunk in progress, so the first
// Read begins by reading a chunk length.
func NewByteStreamReader(r ByteReadCloser) *ByteStreamReader {
	s := new(ByteStreamReader)
	s.r = r
	return s
}

// ReadStreamStatus consumes one status byte from `r` and reports the stream state:
// - `true` when the byte is 1 (stream is "ready")
// - `false` when the byte is 0 (stream is "pending")
// Any other value, or a failure to read the byte, is returned as an error.
func ReadStreamStatus(r ByteReader) (bool, error) {
	b, err := r.ReadByte()
	if err != nil {
		return false, fmt.Errorf("failed to read `stream` status byte: %w", err)
	}
	if b != 0 && b != 1 {
		return false, fmt.Errorf("invalid `stream` status byte %d", b)
	}
	return b == 1, nil
}

// ReadByteStream decodes a byte stream from `r`.
//
// The status byte decides where the data lives: a "ready" stream is read in
// full with ReadByteList and served from memory behind a no-op closer, while a
// "pending" stream is read lazily through a ByteStreamReader over the reader
// obtained from r.Index(path...).
func ReadByteStream(r IndexReader, path ...uint32) (io.ReadCloser, error) {
	slog.Debug("reading byte stream status byte")
	ready, err := ReadStreamStatus(r)
	switch {
	case err != nil:
		return nil, err
	case ready:
		slog.Debug("reading ready byte stream")
		data, err := ReadByteList(r)
		if err != nil {
			return nil, fmt.Errorf("failed to read bytes: %w", err)
		}
		slog.Debug("read ready byte stream", "len", len(data))
		return io.NopCloser(bytes.NewReader(data)), nil
	}

	// Pending: the contents arrive on the indexed sub-reader.
	sub, idxErr := r.Index(path...)
	if idxErr != nil {
		return nil, fmt.Errorf("failed to index reader: %w", idxErr)
	}
	return NewByteStreamReader(sub), nil
}

// ReadStream reads a stream of T values from `r`, decoding each element with `f`.
//
// A status byte is read first (see ReadStreamStatus):
//   - "ready": the whole stream follows inline; it is read with ReadList and
//     returned through NewCompleteReceiver.
//   - "pending": the reader at `path` is obtained via r.Index and wrapped with
//     NewDecodeReceiver. Each decode step reads a 32-bit chunk length followed
//     by that many elements; a zero length yields io.EOF.
//
// The chunk length comes from the wire, so the element slice is not allocated
// at its full claimed size up front. Only a bounded capacity is reserved and
// the slice grows as elements are actually decoded, which keeps a bogus
// length prefix from forcing a huge allocation before any element is read.
func ReadStream[T any](r IndexReader, f func(IndexReader) (T, error), path ...uint32) (Receiver[[]T], error) {
	slog.Debug("reading stream status byte")
	ok, err := ReadStreamStatus(r)
	if err != nil {
		return nil, err
	}
	if !ok {
		r, err := r.Index(path...)
		if err != nil {
			return nil, fmt.Errorf("failed to index reader: %w", err)
		}
		return NewDecodeReceiver(r, func(r IndexReadCloser) ([]T, error) {
			// Upper bound on the number of elements reserved before decoding.
			const maxPrealloc = 1024

			slog.Debug("reading pending stream chunk length")
			n, err := ReadUint32(r)
			if err != nil {
				return nil, fmt.Errorf("failed to read pending stream chunk length: %w", err)
			}
			if n == 0 {
				slog.Debug("pending stream EOF chunk received")
				return nil, io.EOF
			}
			capHint := n
			if capHint > maxPrealloc {
				capHint = maxPrealloc
			}
			vs := make([]T, 0, capHint)
			for i := uint32(0); i < n; i++ {
				v, err := f(r)
				if err != nil {
					return nil, fmt.Errorf("failed to read pending stream chunk element %d: %w", i, err)
				}
				vs = append(vs, v)
			}
			return vs, nil
		}), nil
	}
	slog.Debug("reading ready stream")
	vs, err := ReadList(r, f)
	if err != nil {
		return nil, fmt.Errorf("failed to read ready stream: %w", err)
	}
	slog.Debug("read ready stream", "len", len(vs))
	return NewCompleteReceiver(vs), nil
}

// WriteByteStream writes the contents of r to w as a pending byte stream.
//
// It writes the `stream::pending` status byte (0) to w, obtains the writer at
// path via w.Index, and streams r into it with ByteStreamWriter.WriteTo. chunk
// is the scratch buffer used for reads from r; if it is empty, WriteTo
// allocates one.
//
// The indexed writer is closed before returning, whether or not streaming
// succeeded. A close failure is returned unless another error is already being
// returned, in which case it is only logged.
func WriteByteStream(r io.Reader, w IndexWriter, chunk []byte, path ...uint32) (err error) {
	slog.Debug("writing byte stream `stream::pending` status byte")
	if err := w.WriteByte(0); err != nil {
		return fmt.Errorf("failed to write `stream::pending` byte: %w", err)
	}
	wi, err := w.Index(path...)
	if err != nil {
		// w is a writer; the message previously said "reader".
		return fmt.Errorf("failed to index writer: %w", err)
	}
	defer func() {
		if cErr := wi.Close(); cErr != nil {
			if err == nil {
				err = fmt.Errorf("failed to close pending byte stream: %w", cErr)
			} else {
				slog.Warn("failed to close pending byte stream", "err", cErr)
			}
		}
	}()
	s := &ByteStreamWriter{r: r, chunk: chunk}
	if err := s.WriteTo(wi); err != nil {
		return fmt.Errorf("failed to write stream contents: %w", err)
	}
	return nil
}